DrishtiSharma commited on
Commit
4900612
Β·
verified Β·
1 Parent(s): 3fb2372

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -64,22 +64,16 @@ layer_agent_config_rec = {
64
  }
65
 
66
 
67
- def stream_response(messages: Iterable[ResponseChunk]):
68
  layer_outputs = {}
69
- lock = threading.Lock()
70
-
71
- def safe_append(layer, delta):
72
- with lock:
73
- if layer not in layer_outputs:
74
- layer_outputs[layer] = []
75
- layer_outputs[layer].append(delta)
76
 
77
- for message in messages:
78
  if message['response_type'] == 'intermediate':
79
  layer = message['metadata']['layer']
80
- safe_append(layer, message['delta'])
 
 
81
  else:
82
- # Display accumulated layer outputs
83
  for layer, outputs in layer_outputs.items():
84
  st.write(f"Layer {layer}")
85
  cols = st.columns(len(outputs))
@@ -87,12 +81,18 @@ def stream_response(messages: Iterable[ResponseChunk]):
87
  with cols[i]:
88
  st.expander(label=f"Agent {i+1}", expanded=False).write(output)
89
 
90
- # Clear layer outputs for the next iteration
91
- layer_outputs = {}
92
-
93
- # Yield the main agent's output
94
  yield message['delta']
95
 
 
 
 
 
 
 
 
 
 
96
  def set_moa_agent(
97
  main_model: str = default_config['main_model'],
98
  cycles: int = default_config['cycles'],
 
64
  }
65
 
66
 
67
+ async def async_stream_response(messages: Union[Iterable[ResponseChunk], AsyncIterable[ResponseChunk]]):
68
  layer_outputs = {}
 
 
 
 
 
 
 
69
 
70
+ async def process_message(message):
71
  if message['response_type'] == 'intermediate':
72
  layer = message['metadata']['layer']
73
+ if layer not in layer_outputs:
74
+ layer_outputs[layer] = []
75
+ layer_outputs[layer].append(message['delta'])
76
  else:
 
77
  for layer, outputs in layer_outputs.items():
78
  st.write(f"Layer {layer}")
79
  cols = st.columns(len(outputs))
 
81
  with cols[i]:
82
  st.expander(label=f"Agent {i+1}", expanded=False).write(output)
83
 
84
+ layer_outputs.clear()
 
 
 
85
  yield message['delta']
86
 
87
+ if isinstance(messages, AsyncIterable):
88
+ # Process asynchronous messages
89
+ async for message in messages:
90
+ await process_message(message)
91
+ else:
92
+ # Process synchronous messages
93
+ for message in messages:
94
+ await process_message(message)
95
+
96
  def set_moa_agent(
97
  main_model: str = default_config['main_model'],
98
  cycles: int = default_config['cycles'],