Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 392aa3d

Browse files
committed
WIP
1 parent e794d57 commit 392aa3d

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

src/codegate/db/connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
131131
# Just store the model responses in the list of JSON objects.
132132
for output in outputs:
133133
full_outputs.append(output.output)
134+
print("-----> FULL_OUTPUTS: ", full_outputs)
134135
output_db.output = json.dumps(full_outputs)
136+
print("-----> DB OUTPUT: ", output_db.output)
135137

136138
sql = text(
137139
"""

src/codegate/providers/copilot/provider.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -575,12 +575,13 @@ def _ensure_output_processor(self) -> None:
575575
# Already initialized, no need to reinitialize
576576
return
577577

578-
# this is a hotfix - we shortcut before selecting the output pipeline for FIM
579-
# because our FIM output pipeline is actually empty as of now. We should fix this
580-
# but don't have any immediate need.
581-
is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
582-
if is_fim:
583-
return
578+
# # this is a hotfix - we shortcut before selecting the output pipeline for FIM
579+
# # because our FIM output pipeline is actually empty as of now. We should fix this
580+
# # but don't have any immediate need.
581+
# is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
582+
# if is_fim:
583+
# return
584+
#
584585

585586
logger.debug("Tracking context for pipeline processing")
586587
self.sse_processor = SSEProcessor()
@@ -601,16 +602,25 @@ async def _process_stream(self):
601602
async def stream_iterator():
602603
while True:
603604
incoming_record = await self.stream_queue.get()
605+
if incoming_record.get("type") == "done":
606+
break
607+
604608
record_content = incoming_record.get("content", {})
605609

606610
streaming_choices = []
607611
for choice in record_content.get("choices", []):
612+
is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
613+
if is_fim:
614+
content = choice.get("text", "")
615+
else:
616+
content = choice.get("delta", {}).get("content")
617+
608618
streaming_choices.append(
609619
StreamingChoices(
610620
finish_reason=choice.get("finish_reason", None),
611621
index=0,
612622
delta=Delta(
613-
content=choice.get("delta", {}).get("content"), role="assistant"
623+
content=content, role="assistant"
614624
),
615625
logprobs=None,
616626
)
@@ -624,12 +634,16 @@ async def stream_iterator():
624634
model=record_content.get("model", ""),
625635
object="chat.completion.chunk",
626636
)
637+
print("---> YIELDING", mr)
627638
yield mr
628639

629640
async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
641+
print("----> RECEIVED RECORD", record)
630642
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
643+
# if fim, then put the content into text
631644
sse_data = f"data:{chunk}\n\n".encode("utf-8")
632645
chunk_size = hex(len(sse_data))[2:] + "\r\n"
646+
print("WRITING CHUNK: ", chunk)
633647
self._proxy_transport_write(chunk_size.encode())
634648
self._proxy_transport_write(sse_data)
635649
self._proxy_transport_write(b"\r\n")
@@ -648,6 +662,7 @@ async def stream_iterator():
648662

649663
def _process_chunk(self, chunk: bytes):
650664
records = self.sse_processor.process_chunk(chunk)
665+
print("RECEIVED RECORDS", records)
651666

652667
for record in records:
653668
if self.stream_queue is None:
@@ -658,6 +673,9 @@ def _process_chunk(self, chunk: bytes):
658673
self.stream_queue.put_nowait(record)
659674

660675
def _proxy_transport_write(self, data: bytes):
676+
if not self.proxy.transport or self.proxy.transport.is_closing():
677+
print("TRIED TO WRITE TO A CLOSED TRANSPORT")
678+
return
661679
self.proxy.transport.write(data)
662680

663681
def data_received(self, data: bytes) -> None:
@@ -682,6 +700,7 @@ def data_received(self, data: bytes) -> None:
682700

683701
data = data[header_end + 4 :]
684702

703+
print("PROCESSING CHUNK: ", data)
685704
self._process_chunk(data)
686705

687706
def connection_lost(self, exc: Optional[Exception]) -> None:

src/codegate/providers/copilot/streaming.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ def process_chunk(self, chunk: bytes) -> list:
4242
data = json.loads(data_content)
4343
records.append({"type": "data", "content": data})
4444
except json.JSONDecodeError:
45-
print(f"Failed to parse JSON: {data_content}")
45+
print(f"SSEProcessor failed to parse JSON: {data_content}")
4646

47+
print("-----> RECORDS: ", records)
4748
return records
4849

4950
def get_pending(self):

0 commit comments

Comments (0)