English
John6666 committed on
Commit
1d43287
·
verified ·
1 Parent(s): 428b5a8

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -10,14 +10,15 @@ torch._dynamo.config.suppress_errors = True
10
 
11
  #from huggingface_inference_toolkit.logging import logger
12
 
13
def compile_pipeline(pipe):
    """Swap ``pipe.transformer`` for its ``torch.compile``-wrapped form.

    The transformer is first moved to the ``channels_last`` memory format,
    then the compiled wrapper is stored back on ``pipe``. The same ``pipe``
    object is returned after being modified in place.
    """
    transformer = pipe.transformer
    transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(
        transformer,
        mode="reduce-overhead",
        fullgraph=False,
        dynamic=False,
        backend="inductor",
    )
    return pipe
17
 
18
  class EndpointHandler:
19
  def __init__(self, path="", **kwargs: Any) -> None: # type: ignore
20
- is_compile = False
21
  #repo_id = "camenduru/FLUX.1-dev-diffusers"
22
  repo_id = "NoMoreCopyright/FLUX.1-dev-test"
23
  dtype = torch.bfloat16
 
10
 
11
  #from huggingface_inference_toolkit.logging import logger
12
 
13
def compile_pipeline(pipe) -> Any:
    """Prepare ``pipe.transformer`` for compiled execution and return ``pipe``.

    Steps, in order:
      1. call ``fuse_qkv_projections()`` on the transformer,
      2. move the transformer to the ``channels_last`` memory format,
      3. replace ``pipe.transformer`` with the ``torch.compile`` wrapper.

    ``pipe`` is modified in place; the returned object is the same ``pipe``.
    """
    transformer = pipe.transformer
    transformer.fuse_qkv_projections()
    transformer.to(memory_format=torch.channels_last)

    compile_options = {
        "mode": "reduce-overhead",
        "fullgraph": False,
        "dynamic": False,
        "backend": "inductor",
    }
    pipe.transformer = torch.compile(transformer, **compile_options)
    return pipe
18
 
19
  class EndpointHandler:
20
  def __init__(self, path="", **kwargs: Any) -> None: # type: ignore
21
+ is_compile = True
22
  #repo_id = "camenduru/FLUX.1-dev-diffusers"
23
  repo_id = "NoMoreCopyright/FLUX.1-dev-test"
24
  dtype = torch.bfloat16