diff options
author | arcticfaded <jbelt021@fiu.edu> | 2022-10-17 19:10:36 +0000 |
---|---|---|
committer | arcticfaded <jbelt021@fiu.edu> | 2022-10-17 19:10:36 +0000 |
commit | f80e914ac4aa69a9783b4040813253500b34d925 (patch) | |
tree | 0c738fc85532cfddc1c2888ef4a682a452b5ddac /modules/api/processing.py | |
parent | d42125baf62880854ad06af06c15c23e7e50cca6 (diff) | |
download | stable-diffusion-webui-gfx803-f80e914ac4aa69a9783b4040813253500b34d925.tar.gz stable-diffusion-webui-gfx803-f80e914ac4aa69a9783b4040813253500b34d925.tar.bz2 stable-diffusion-webui-gfx803-f80e914ac4aa69a9783b4040813253500b34d925.zip |
example API working with gradio
Diffstat (limited to 'modules/api/processing.py')
-rw-r--r-- | modules/api/processing.py | 56 |
1 files changed, 38 insertions, 18 deletions
diff --git a/modules/api/processing.py b/modules/api/processing.py index e4df93c5..b6798241 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu import inspect +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + class ModelDef(BaseModel): """Assistance Class for Pydantic Dynamic Model Generation""" @@ -14,7 +32,7 @@ class ModelDef(BaseModel): field_value: Any -class pydanticModelGenerator: +class PydanticModelGenerator: """ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: source_data is a snapshot of the default values produced by the class @@ -24,30 +42,33 @@ class pydanticModelGenerator: def __init__( self, model_name: str = None, - source_data: {} = {}, - params: Dict = {}, - overrides: Dict = {}, - optionals: Dict = {}, + class_instance = None ): - def field_type_generator(k, v, overrides, optionals): - field_type = str if not overrides.get(k) else overrides[k]["type"] - if v is None: - field_type = Any - else: - field_type = type(v) + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation return Optional[field_type] + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + self._model_name = model_name - self._json_data = source_data + self._class_data = merge_class_params(class_instance) self._model_def = [ ModelDef( field=underscore(k), field_alias=k, - field_type=field_type_generator(k, v, overrides, optionals), - field_value=v + field_type=field_type_generator(k, v), + field_value=v.default ) - for (k,v) in source_data.items() if k in params + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] def generate_model(self): @@ -60,8 +81,7 @@ class pydanticModelGenerator: } DynamicModel = create_model(self._model_name, **fields) DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", - StableDiffusionProcessing().__dict__, - inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() +StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() |