 .github/ISSUE_TEMPLATE/bug_report.yml    |  67
 configs/alt-diffusion-m18-inference.yaml |  73
 extensions-builtin/Lora/network_glora.py |  33
 extensions-builtin/Lora/networks.py      |   2
 javascript/imageviewer.js                |   7
 modules/api/api.py                       |   3
 modules/processing.py                    |   5
 modules/sd_hijack.py                     |   4
 modules/sd_models.py                     |   8
 modules/sd_models_config.py              |   5
 modules/shared_options.py                |   2
 modules/ui.py                            |   2
 modules/xlmr_m18.py                      | 164
 13 files changed, 345 insertions(+), 30 deletions(-)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index cf6a2be8..5876e941 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -1,25 +1,45 @@
name: Bug Report
-description: You think somethings is broken in the UI
+description: You think something is broken in the UI
title: "[Bug]: "
labels: ["bug-report"]
body:
+ - type: markdown
+ attributes:
+ value: |
+ > The title of the bug report should be short and descriptive.
+ > Use relevant keywords for searchability.
+ > Do not leave it blank, but also do not put an entire error log in it.
- type: checkboxes
attributes:
- label: Is there an existing issue for this?
- description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+ label: Checklist
+ description: |
+ Please perform basic debugging to see if an extension or your configuration is the cause of the issue.
+ Basic debug procedure:
+  1. Disable all third-party extensions - check if an extension is the cause
+  2. Update extensions and webui - sometimes things just need to be updated
+  3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
+  4. Delete venv with third-party extensions disabled - sometimes extensions might cause the wrong libraries to be installed
+  5. Try a fresh installation of webui in a different directory - see if a clean installation solves the issue
+ Before making an issue report, please check that the issue hasn't been reported recently.
options:
- - label: I have searched the existing issues and checked the recent builds/commits
- required: true
+ - label: The issue exists after disabling all extensions
+ - label: The issue exists on a clean installation of webui
+ - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
+ - label: The issue exists in the current version of the webui
+ - label: The issue has not been reported recently
+ - label: The issue has been reported before but has not been fixed yet
- type: markdown
attributes:
value: |
- *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+ > Please fill this form with as much information as possible. Don't forget to upload the "Sysinfo" file and fill "What browsers", and provide screenshots if possible
- type: textarea
id: what-did
attributes:
label: What happened?
description: Tell us what happened in a very clear and simple way
+ placeholder: |
+ txt2img is not working as intended.
validations:
required: true
- type: textarea
@@ -27,9 +47,9 @@ body:
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step instructions on how to reproduce the bug
- value: |
- 1. Go to ....
- 2. Press ....
+ placeholder: |
+ 1. Go to ...
+ 2. Press ...
3. ...
validations:
required: true
@@ -38,13 +58,8 @@ body:
attributes:
label: What should have happened?
description: Tell us what you think the normal behavior should be
- validations:
- required: true
- - type: textarea
- id: sysinfo
- attributes:
- label: Sysinfo
- description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+ placeholder: |
+ WebUI should ...
validations:
required: true
- type: dropdown
@@ -58,12 +73,25 @@ body:
- Brave
- Apple Safari
- Microsoft Edge
+ - Android
+ - iOS
- Other
- type: textarea
+ id: sysinfo
+ attributes:
+ label: Sysinfo
+ description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+ placeholder: |
+ 1. Go to WebUI Settings -> Sysinfo -> Download system info.
+ If WebUI fails to launch, use the --dump-sysinfo commandline argument to generate the file
+ 2. Upload the Sysinfo as an attached file. Do NOT paste it in as plain text.
+ validations:
+ required: true
+ - type: textarea
id: logs
attributes:
label: Console logs
- description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+ description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occurred. If it's very long, provide a link to pastebin or similar service.
render: Shell
validations:
required: true
@@ -71,4 +99,7 @@ body:
id: misc
attributes:
label: Additional information
- description: Please provide us with any relevant additional info or context.
+ description: |
+ Please provide us with any relevant additional info or context.
+ Examples:
+  I have updated my GPU driver recently.
diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml
new file mode 100644
index 00000000..41a031d5
--- /dev/null
+++ b/configs/alt-diffusion-m18-inference.yaml
@@ -0,0 +1,73 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: modules.xlmr_m18.BertSeriesModelWithTransformation
+ params:
+ name: "XLMR-Large"
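
For context, a config like this is normally consumed through OmegaConf and ldm's instantiate_from_config; the following is a minimal sketch of that loading path, assuming the Stable Diffusion ldm package is importable and the yaml sits at the path shown:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# parse the yaml and build the LatentDiffusion model it describes,
# including the modules.xlmr_m18 cond stage declared above
config = OmegaConf.load("configs/alt-diffusion-m18-inference.yaml")
model = instantiate_from_config(config.model)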
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
new file mode 100644
index 00000000..492d4870
--- /dev/null
+++ b/extensions-builtin/Lora/network_glora.py
@@ -0,0 +1,33 @@
+
+import network
+
+class ModuleTypeGLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+ return NetworkModuleGLora(net, weights)
+
+ return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["a1.weight"]
+ self.w1b = weights.w["b1.weight"]
+ self.w2a = weights.w["a2.weight"]
+ self.w2b = weights.w["b2.weight"]
+
+ def calc_updown(self, orig_weight):
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w1a.size(0), w1b.size(1)]
+ updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
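
A toy numeric sketch of the GLoRA update computed by calc_updown above; the tensor shapes here are illustrative assumptions rather than values from a real checkpoint, and finalize_updown additionally applies alpha/strength scaling that the sketch omits:

import torch

out_dim, in_dim, rank = 8, 8, 2
orig_weight = torch.randn(out_dim, in_dim)  # frozen base weight

w1a = torch.randn(rank, in_dim)             # "a1.weight"
w1b = torch.randn(rank, in_dim)             # "b1.weight"
w2a = torch.randn(in_dim, rank)             # "a2.weight"
w2b = torch.randn(out_dim, rank)            # "b2.weight"

# one purely low-rank term plus one term that modulates the original weight
updown = (w2b @ w1b) + ((orig_weight @ w2a) @ w1a)
print((orig_weight + updown).shape)         # torch.Size([8, 8])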
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index d5f0f9f1..60d8dec4 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,6 +5,7 @@ import re
import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
@@ -26,6 +27,7 @@ module_types = [
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
]
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index c21d396e..e4dae91b 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
-
- if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+ let preview = gradioApp().querySelectorAll('.livePreview > img');
+ if (preview.length > 0) {
+ // show preview image if available
+ modalImage.src = preview[preview.length - 1].src;
+ } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
const modal = gradioApp().getElementById("lightboxModal");
diff --git a/modules/api/api.py b/modules/api/api.py
index 905ef9c9..efedafa4 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -103,7 +103,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
-
+ if isinstance(image, str):
+ return image
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
diff --git a/modules/processing.py b/modules/processing.py
index 36bc94f7..40598f5c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -711,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.before_process(p)
- stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data}
+ stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys()}
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -960,6 +960,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.nextjob()
+ if not infotexts:
+ infotexts.append(Processed(p, []).infotext(p, 0))
+
p.color_corrections = None
index_of_first_image = 0
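
To see why the stored_opts change matters, here is a minimal sketch of the store/apply/restore pattern around override_settings, using a simplified stand-in Options class (opts, get_default, and the option name are assumptions for illustration):

class Options:
    defaults = {"CLIP_stop_at_last_layers": 1}
    def __init__(self):
        self.data = {}  # only options the user has explicitly changed live here
    def get_default(self, key):
        return self.defaults[key]

opts = Options()
override_settings = {"CLIP_stop_at_last_layers": 2}

# fall back to the default for options that were never changed; filtering
# those keys out (as the old code did) left nothing to restore afterwards
stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k)
               for k in override_settings}

opts.data.update(override_settings)  # apply overrides for this generation
try:
    pass  # ... run the job ...
finally:
    opts.data.update(stored_opts)    # put the previous values back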
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 22a1eb5c..bc5fbcd3 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -211,7 +211,7 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7f8502f5..c8efeedc 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -357,12 +357,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
-
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
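
The reorder above caches the state dict before load_state_dict can touch it (webui's optimized loaders may consume entries from the dict as they assign weights), and dict.copy() is shallow, so the cache costs only a new key mapping rather than duplicated tensors. A minimal illustration of that copy semantics:

import torch

state_dict = {"w": torch.zeros(2)}
cached = state_dict.copy()   # shallow: shares tensors, independent key set

state_dict.pop("w")          # e.g. a memory-saving loader consuming an entry
print("w" in cached)         # True - the cached mapping is unaffected
print(cached["w"].shape)     # torch.Size([2]) - same underlying tensor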
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 08dd03f1..deab2f6e 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
+
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+ if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+ return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
diff --git a/modules/shared_options.py b/modules/shared_options.py
index ab9b0072..ce395302 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -62,6 +62,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
+
+ "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
diff --git a/modules/ui.py b/modules/ui.py
index 3d1f5285..bcf39199 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1296,7 +1296,7 @@ def create_ui():
loadsave.setup_ui()
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
new file mode 100644
index 00000000..a727e865
--- /dev/null
+++ b/modules/xlmr_m18.py
@@ -0,0 +1,164 @@
+from transformers import BertPreTrainedModel, BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel, XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, project_dim=512, pooler_fn='cls', learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob = 0.1
+ config.bos_token_id = 0
+ config.eos_token_id = 2
+ config.hidden_act = 'gelu'
+ config.hidden_dropout_prob = 0.1
+ config.hidden_size = 1024
+ config.initializer_range = 0.02
+ config.intermediate_size = 4096
+ config.layer_norm_eps = 1e-05
+ config.max_position_embeddings = 514
+ config.num_attention_heads = 16
+ config.num_hidden_layers = 24
+ config.output_past = True
+ config.pad_token_id = 1
+ config.position_embedding_type = "absolute"
+ config.type_vocab_size = 1
+ config.use_cache = True
+ config.vocab_size = 250002
+ config.project_dim = 1024
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+
+ self.has_pre_transformation = True
+ if self.has_pre_transformation:
+ self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+ self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_init()
+
+ def encode(self, c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ # the tokenizer already returned tensors (return_tensors="pt"), so just move them
+ text["input_ids"] = text["input_ids"].to(device)
+ text["attention_mask"] = text["attention_mask"].to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ):
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ if self.has_pre_transformation:
+ sequence_output2 = outputs["hidden_states"][-2]
+ sequence_output2 = self.pre_LN(sequence_output2)
+ projection_state2 = self.transformation_pre(sequence_output2)
+
+ return {
+ "projection_state": projection_state2,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+ else:
+ projection_state = self.transformation(outputs.last_hidden_state)
+ return {
+ "projection_state": projection_state,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
config_class = RobertaSeriesConfig
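
Finally, a hedged usage sketch of the new cond stage model; run from the webui root, it builds the XLMR-Large-sized architecture with random weights (and downloads the xlm-roberta-large tokenizer), so the output is only useful as a shape check:

import torch
from modules.xlmr_m18 import BertSeriesModelWithTransformation

model = BertSeriesModelWithTransformation()  # no config -> XLMR-Large defaults from __init__
with torch.no_grad():
    emb = model.encode(["a photo of a cat"])
print(emb.shape)  # expected: torch.Size([1, 77, 1024]) - penultimate layer, projected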