diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-30 06:10:22 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-30 06:10:22 +0000 |
commit | 149784202cca8612b43629c601ee27cfda64e623 (patch) | |
tree | 5e6121de8c0b0e935159f01fdedbb7b8a221f8b8 /modules | |
parent | 060ee5d3a7ba258f944d0c891f90ac4a65411446 (diff) | |
download | stable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.tar.gz stable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.tar.bz2 stable-diffusion-webui-gfx803-149784202cca8612b43629c601ee27cfda64e623.zip |
rework #3722 to not introduce duplicate code
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 43 | ||||
-rw-r--r-- | modules/shared.py | 22 |
2 files changed, 32 insertions, 33 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 5c5b210f..6c06d449 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo -# copy from wrap_gradio_gpu_call of webui.py -# because queue lock will be acquired in api handlers -# and time start needs to be set -# the function has been modified into two parts - -def before_gpu_call(): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None - shared.state.time_start = time.time() - -def after_gpu_call(): - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() def upscaler_to_index(name: str): try: @@ -41,8 +16,10 @@ def upscaler_to_index(name: str): except: raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + def setUpscalers(req: dict): reqDict = vars(req) reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) @@ -51,6 +28,7 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -78,10 +56,13 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -119,11 +100,13 @@ class Api: imgs = [img] * p.batch_size p.init_images = imgs - # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/shared.py b/modules/shared.py index f7b0990c..e4f163c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -144,9 +144,6 @@ class State: self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
-
def dict(self):
obj = {
"skipped": self.skipped,
@@ -160,6 +157,25 @@ class State: return obj
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
state = State()
|