aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-30 06:10:22 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-30 06:10:22 +0000
commit149784202cca8612b43629c601ee27cfda64e623 (patch)
tree5e6121de8c0b0e935159f01fdedbb7b8a221f8b8 /modules
parent060ee5d3a7ba258f944d0c891f90ac4a65411446 (diff)
downloadstable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.tar.gz
stable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.tar.bz2
stable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.zip
rework #3722 to not introduce duplicate code
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py43
-rw-r--r--modules/shared.py22
2 files changed, 32 insertions, 33 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5c5b210f..6c06d449 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
- devices.torch_gc()
-
- shared.state.sampling_step = 0
- shared.state.job_count = -1
- shared.state.job_no = 0
- shared.state.job_timestamp = shared.state.get_job_timestamp()
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.current_image_sampling_step = 0
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.textinfo = None
- shared.state.time_start = time.time()
-
-def after_gpu_call():
- shared.state.job = ""
- shared.state.job_count = 0
-
- devices.torch_gc()
def upscaler_to_index(name: str):
try:
@@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -78,10 +56,13 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -119,11 +100,13 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
- # Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/shared.py b/modules/shared.py
index f7b0990c..e4f163c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -144,9 +144,6 @@ class State:
self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
-
def dict(self):
obj = {
"skipped": self.skipped,
@@ -160,6 +157,25 @@ class State:
return obj
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
state = State()