yym68686 commited on
Commit
4b94129
·
1 Parent(s): e0e0c2d

🐛 Bug: 1. Fix the bug where the option request attempts to parse the OK stream message.

Browse files
Files changed (1) hide show
  1. main.py +14 -5
main.py CHANGED
@@ -267,7 +267,7 @@ class LoggingStreamingResponse(Response):
267
  logger.info(f"{line}")
268
  if line.startswith("data:"):
269
  line = line.lstrip("data: ")
270
- if not line.startswith("[DONE]"):
271
  try:
272
  resp: dict = json.loads(line)
273
  input_tokens = safe_get(resp, "message", "usage", "input_tokens", default=0)
@@ -587,7 +587,8 @@ def weighted_round_robin(weights):
587
  import asyncio
588
  class ModelRequestHandler:
589
  def __init__(self):
590
- self.last_provider_index = -1
 
591
 
592
  def get_matching_providers(self, model_name, token):
593
  config = app.state.config
@@ -708,10 +709,18 @@ class ModelRequestHandler:
708
  status_code = 500
709
  error_message = None
710
  num_providers = len(providers)
711
- start_index = self.last_provider_index + 1 if use_round_robin else 0
 
 
 
 
 
 
 
 
712
  for i in range(num_providers + 1):
713
- self.last_provider_index = (start_index + i) % num_providers
714
- provider = providers[self.last_provider_index]
715
  try:
716
  response = await process_request(request, provider, endpoint, token)
717
  return response
 
267
  logger.info(f"{line}")
268
  if line.startswith("data:"):
269
  line = line.lstrip("data: ")
270
+ if not line.startswith("[DONE]") and not line.startswith("OK"):
271
  try:
272
  resp: dict = json.loads(line)
273
  input_tokens = safe_get(resp, "message", "usage", "input_tokens", default=0)
 
587
  import asyncio
588
  class ModelRequestHandler:
589
  def __init__(self):
590
+ self.last_provider_indices = defaultdict(lambda: -1)
591
+ self.locks = defaultdict(asyncio.Lock)
592
 
593
  def get_matching_providers(self, model_name, token):
594
  config = app.state.config
 
709
  status_code = 500
710
  error_message = None
711
  num_providers = len(providers)
712
+ model_name = request.model
713
+
714
+ if use_round_robin:
715
+ async with self.locks[model_name]:
716
+ self.last_provider_indices[model_name] = (self.last_provider_indices[model_name] + 1) % num_providers
717
+ start_index = self.last_provider_indices[model_name]
718
+ else:
719
+ start_index = 0
720
+
721
  for i in range(num_providers + 1):
722
+ current_index = (start_index + i) % num_providers
723
+ provider = providers[current_index]
724
  try:
725
  response = await process_request(request, provider, endpoint, token)
726
  return response