|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import importlib |
15 | 16 | from pathlib import Path |
| 17 | +import sys |
16 | 18 | import textwrap |
17 | 19 | from typing import Optional |
18 | 20 | from unittest.mock import AsyncMock |
@@ -898,5 +900,143 @@ def test_should_append_event_other_event(self): |
898 | 900 | assert self.runner._should_append_event(event, is_live_call=True) is True |
899 | 901 |
|
900 | 902 |
|
| 903 | +@pytest.fixture |
| 904 | +def user_agent_module(tmp_path, monkeypatch): |
| 905 | + """Fixture that creates a temporary user agent module for testing. |
| 906 | +
|
| 907 | + Yields a callable that creates an agent module with the given name and |
| 908 | + returns the loaded agent. |
| 909 | + """ |
| 910 | + created_modules = [] |
| 911 | + original_path = None |
| 912 | + |
| 913 | + def _create_agent(agent_dir_name: str): |
| 914 | + nonlocal original_path |
| 915 | + agent_dir = tmp_path / "agents" / agent_dir_name |
| 916 | + agent_dir.mkdir(parents=True, exist_ok=True) |
| 917 | + (tmp_path / "agents" / "__init__.py").write_text("", encoding="utf-8") |
| 918 | + (agent_dir / "__init__.py").write_text("", encoding="utf-8") |
| 919 | + |
| 920 | + agent_source = f"""\ |
| 921 | +from google.adk.agents.llm_agent import LlmAgent |
| 922 | +
|
| 923 | +class MyAgent(LlmAgent): |
| 924 | + pass |
| 925 | +
|
| 926 | +root_agent = MyAgent(name="{agent_dir_name}", model="gemini-2.0-flash") |
| 927 | +""" |
| 928 | + (agent_dir / "agent.py").write_text(agent_source, encoding="utf-8") |
| 929 | + |
| 930 | + monkeypatch.chdir(tmp_path) |
| 931 | + if original_path is None: |
| 932 | + original_path = str(tmp_path) |
| 933 | + sys.path.insert(0, original_path) |
| 934 | + |
| 935 | + module_name = f"agents.{agent_dir_name}.agent" |
| 936 | + module = importlib.import_module(module_name) |
| 937 | + created_modules.append(module_name) |
| 938 | + return module.root_agent |
| 939 | + |
| 940 | + yield _create_agent |
| 941 | + |
| 942 | + # Cleanup |
| 943 | + if original_path and original_path in sys.path: |
| 944 | + sys.path.remove(original_path) |
| 945 | + for mod_name in list(sys.modules.keys()): |
| 946 | + if mod_name.startswith("agents"): |
| 947 | + del sys.modules[mod_name] |
| 948 | + |
| 949 | + |
| 950 | +class TestRunnerInferAgentOrigin: |
| 951 | + """Tests for Runner._infer_agent_origin method.""" |
| 952 | + |
| 953 | + def setup_method(self): |
| 954 | + """Set up test fixtures.""" |
| 955 | + self.session_service = InMemorySessionService() |
| 956 | + self.artifact_service = InMemoryArtifactService() |
| 957 | + |
| 958 | + def test_infer_agent_origin_uses_adk_metadata_when_available(self): |
| 959 | + """Test that _infer_agent_origin uses _adk_origin_* metadata when set.""" |
| 960 | + agent = MockLlmAgent("test_agent") |
| 961 | + # Simulate metadata set by AgentLoader |
| 962 | + agent._adk_origin_app_name = "my_app" |
| 963 | + agent._adk_origin_path = Path("/workspace/agents/my_app") |
| 964 | + |
| 965 | + runner = Runner( |
| 966 | + app_name="my_app", |
| 967 | + agent=agent, |
| 968 | + session_service=self.session_service, |
| 969 | + artifact_service=self.artifact_service, |
| 970 | + ) |
| 971 | + |
| 972 | + origin_name, origin_path = runner._infer_agent_origin(agent) |
| 973 | + assert origin_name == "my_app" |
| 974 | + assert origin_path == Path("/workspace/agents/my_app") |
| 975 | + |
| 976 | + def test_infer_agent_origin_no_false_positive_for_direct_llm_agent(self): |
| 977 | + """Test that using LlmAgent directly doesn't trigger mismatch warning. |
| 978 | +
|
| 979 | + Regression test for GitHub issue #3143: Users who instantiate LlmAgent |
| 980 | + directly and run from a directory that is a parent of the ADK installation |
| 981 | + were getting false positive 'App name mismatch' warnings. |
| 982 | +
|
| 983 | + This also verifies that _infer_agent_origin returns None for ADK internal |
| 984 | + modules (google.adk.*). |
| 985 | + """ |
| 986 | + agent = LlmAgent( |
| 987 | + name="my_custom_agent", |
| 988 | + model="gemini-2.0-flash", |
| 989 | + ) |
| 990 | + |
| 991 | + runner = Runner( |
| 992 | + app_name="my_custom_agent", |
| 993 | + agent=agent, |
| 994 | + session_service=self.session_service, |
| 995 | + artifact_service=self.artifact_service, |
| 996 | + ) |
| 997 | + |
| 998 | + # Should return None for ADK internal modules |
| 999 | + origin_name, _ = runner._infer_agent_origin(agent) |
| 1000 | + assert origin_name is None |
| 1001 | + # No mismatch warning should be generated |
| 1002 | + assert runner._app_name_alignment_hint is None |
| 1003 | + |
| 1004 | + def test_infer_agent_origin_with_subclassed_agent_in_user_code( |
| 1005 | + self, user_agent_module |
| 1006 | + ): |
| 1007 | + """Test that subclassed agents in user code still trigger origin inference.""" |
| 1008 | + agent = user_agent_module("my_agent") |
| 1009 | + |
| 1010 | + runner = Runner( |
| 1011 | + app_name="my_agent", |
| 1012 | + agent=agent, |
| 1013 | + session_service=self.session_service, |
| 1014 | + artifact_service=self.artifact_service, |
| 1015 | + ) |
| 1016 | + |
| 1017 | + # Should infer origin correctly from user's code |
| 1018 | + origin_name, origin_path = runner._infer_agent_origin(agent) |
| 1019 | + assert origin_name == "my_agent" |
| 1020 | + assert runner._app_name_alignment_hint is None |
| 1021 | + |
| 1022 | + def test_infer_agent_origin_detects_mismatch_for_user_agent( |
| 1023 | + self, user_agent_module |
| 1024 | + ): |
| 1025 | + """Test that mismatched app_name is detected for user-defined agents.""" |
| 1026 | + agent = user_agent_module("actual_name") |
| 1027 | + |
| 1028 | + runner = Runner( |
| 1029 | + app_name="wrong_name", # Intentionally wrong |
| 1030 | + agent=agent, |
| 1031 | + session_service=self.session_service, |
| 1032 | + artifact_service=self.artifact_service, |
| 1033 | + ) |
| 1034 | + |
| 1035 | + # Should detect the mismatch |
| 1036 | + assert runner._app_name_alignment_hint is not None |
| 1037 | + assert "wrong_name" in runner._app_name_alignment_hint |
| 1038 | + assert "actual_name" in runner._app_name_alignment_hint |
| 1039 | + |
| 1040 | + |
901 | 1041 | if __name__ == "__main__": |
902 | 1042 | pytest.main([__file__]) |
0 commit comments