From 282903bb6798f49af66f6935ee4dc0015895cf7c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 15 Oct 2023 09:41:02 +0300 Subject: repair unload sd checkpoint button --- modules/api/api.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index efedafa4..09083874 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -541,12 +540,12 @@ class Api: return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -566,7 +565,7 @@ class Api: def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): -- cgit v1.2.3