aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
authorΦφ <42910943+Brawlence@users.noreply.github.com>2023-03-09 04:56:19 +0000
committerΦφ <42910943+Brawlence@users.noreply.github.com>2023-03-21 06:28:50 +0000
commit4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1 (patch)
tree943c119f18a0aa7a748f95c989f73e3a9a69a1c3 /modules
parenta9fed7c364061ae6efb37f797b6b522cb3cf7aa2 (diff)
downloadstable-diffusion-webui-gfx803-4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1.tar.gz
stable-diffusion-webui-gfx803-4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1.tar.bz2
stable-diffusion-webui-gfx803-4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1.zip
Unload checkpoints on Request
…to free VRAM. New Action buttons in the settings to manually free and reload checkpoints, essentially juggling models between RAM and VRAM.
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py14
-rw-r--r--modules/sd_models.py22
-rw-r--r--modules/ui.py22
3 files changed, 56 insertions, 2 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 35e17afc..f52f7fef 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -150,6 +150,8 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
def add_api_route(self, path: str, endpoint, **kwargs):
@@ -412,6 +414,16 @@ class Api:
return {}
+ def unloadapi(self):
+ unload_model_weights()
+
+ return {}
+
+ def reloadapi(self):
+ reload_model_weights()
+
+ return {}
+
def skip(self):
shared.state.skip()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f0cb1240..f9dd0521 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
try:
@@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
return sd_model
+
+def unload_model_weights(sd_model=None, info=None):
+ from modules import lowvram, devices, sd_hijack
+ timer = Timer()
+
+ if shared.sd_model:
+
+ # shared.sd_model.cond_stage_model.to(devices.cpu)
+ # shared.sd_model.first_stage_model.to(devices.cpu)
+ shared.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ sd_model = None
+ gc.collect()
+ devices.torch_gc()
+ torch.cuda.empty_cache()
+
+ print(f"Unloaded weights {timer.summary()}.")
+
+ return sd_model \ No newline at end of file
diff --git a/modules/ui.py b/modules/ui.py
index 7e603332..d93ef134 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1491,11 +1491,33 @@ def create_ui():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ with gr.Row():
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+
+ def unload_sd_weights():
+ modules.sd_models.unload_model_weights()
+
+ def reload_sd_weights():
+ modules.sd_models.reload_model_weights()
+
+ unload_sd_model.click(
+ fn=unload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ reload_sd_model.click(
+ fn=reload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
request_notifications.click(
fn=lambda: None,