Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions tests/reasoning/test_reasoning_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest

from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager


class TestReasoningParser(ReasoningParser):
def is_reasoning_end(self, input_ids):
"""
Return True to simulate end of reasoning content.
"""
return True

def extract_content_ids(self, input_ids):
"""
Return input_ids directly for testing.
"""
return input_ids

def extract_reasoning_content(self, model_output, request):
"""
Used for testing non-streaming extraction.
"""
return model_output, model_output

def extract_reasoning_content_streaming(
self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids
):
"""
Return None for streaming extraction; minimal implementation for testing.
"""
return None


class TestReasoningParserManager(unittest.TestCase):
"""
Unit tests for ReasoningParserManager functionality.
"""

def setUp(self):
"""
Save original registry to restore after each test.
"""
self.original_parsers = ReasoningParserManager.reasoning_parsers.copy()

def tearDown(self):
"""
Restore original registry to avoid test pollution.
"""
ReasoningParserManager.reasoning_parsers = self.original_parsers.copy()

def test_register_and_get_parser(self):
"""
Test that a parser can be registered and retrieved successfully.
Verifies normal registration and retrieval functionality.
"""
ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser", force=True)
parser_cls = ReasoningParserManager.get_reasoning_parser("test_parser")
self.assertIs(parser_cls, TestReasoningParser)

def test_register_duplicate_without_force_raises(self):
"""
Test that registering a parser with an existing name without force raises KeyError.
Ensures duplicate registrations are handled correctly.
"""
ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True)
with self.assertRaises(KeyError):
ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False)

def test_register_non_subclass_raises(self):
"""
Test that registering a class not inheriting from ReasoningParser raises TypeError.
Ensures type safety for registered modules.
"""

class NotParser:
pass

with self.assertRaises(TypeError):
ReasoningParserManager.register_module(module=NotParser, name="not_parser")

def test_get_unregistered_parser_raises(self):
"""
Test that retrieving a parser that was not registered raises KeyError.
Ensures get_reasoning_parser handles unknown names correctly.
"""
with self.assertRaises(KeyError):
ReasoningParserManager.get_reasoning_parser("nonexistent_parser")


if __name__ == "__main__":
unittest.main()
Loading