Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 85 additions & 10 deletions lib/rbs/inline/ast/declarations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,27 @@ def type(default_type)
end

# @rbs %a{pure}
# @rbs return Types::t?
# @rbs return: Types::t?
def literal_type
case node.value
infer_type_from_node(node.value)
end

# @rbs %a{pure}
# @rbs return: TypeName?
def constant_name
TypeName.new(name: node.name, namespace: Namespace.empty)
end

def start_line #: Integer
node.location.start_line
end

private

# @rbs node: Prism::Node
# @rbs return: Types::t?
def infer_type_from_node(node)
case node
when Prism::StringNode, Prism::InterpolatedStringNode
BuiltinNames::String.instance_type
when Prism::SymbolNode, Prism::InterpolatedSymbolNode
Expand All @@ -221,22 +239,79 @@ def literal_type
when Prism::FloatNode
BuiltinNames::Float.instance_type
when Prism::ArrayNode
BuiltinNames::Array.instance_type
infer_array_element_type(node)
when Prism::HashNode
BuiltinNames::Hash.instance_type
infer_hash_element_type(node)
when Prism::TrueNode, Prism::FalseNode
Types::Bases::Bool.new(location: nil)
end
end

# @rbs %a{pure}
# @rbs return: TypeName?
def constant_name
TypeName.new(name: node.name, namespace: Namespace.empty)
# @rbs node: Prism::ArrayNode
# @rbs return: Types::t
def infer_array_element_type(node)
return BuiltinNames::Array.instance_type if node.elements.empty?

element_types = [] #: Array[Types::t]
node.elements.each do |elem|
type = infer_type_from_node(elem)
return BuiltinNames::Array.instance_type unless type
element_types << type
end

element_types.uniq!(&:to_s)

# Union types are not currently supported.
case element_types.size
when 1
BuiltinNames::Array.instance_type(element_types.first || raise)
else
BuiltinNames::Array.instance_type(
Types::Bases::Any.new(location: nil)
)
end
end

def start_line #: Integer
node.location.start_line
# @rbs node: Prism::HashNode
# @rbs return: Types::t
def infer_hash_element_type(node)
return BuiltinNames::Hash.instance_type if node.elements.empty?

key_types = [] #: Array[Types::t]
value_types = [] #: Array[Types::t]

node.elements.each do |elem|
case elem
when Prism::AssocNode
key_type = infer_type_from_node(elem.key)
value_type = infer_type_from_node(elem.value)
return BuiltinNames::Hash.instance_type unless key_type && value_type
key_types << key_type
value_types << value_type
else
return BuiltinNames::Hash.instance_type
end
end

key_types.uniq!(&:to_s)
value_types.uniq!(&:to_s)

# Union types are not currently supported.
key_type = case key_types.size
when 1
key_types.first || raise
else
Types::Bases::Any.new(location: nil)
end

value_type = case value_types.size
when 1
value_types.first || raise
else
Types::Bases::Any.new(location: nil)
end

BuiltinNames::Hash.instance_type(key_type, value_type)
end
end

Expand Down
16 changes: 15 additions & 1 deletion sig/generated/rbs/inline/ast/declarations.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ module RBS
def type: (untyped default_type) -> Types::t

# @rbs %a{pure}
# @rbs return Types::t?
# @rbs return: Types::t?
%a{pure}
def literal_type: () -> Types::t?

Expand All @@ -112,6 +112,20 @@ module RBS
def constant_name: () -> TypeName?

def start_line: () -> Integer

private

# @rbs node: Prism::Node
# @rbs return: Types::t?
def infer_type_from_node: (Prism::Node node) -> Types::t?

# @rbs node: Prism::ArrayNode
# @rbs return: Types::t
def infer_array_element_type: (Prism::ArrayNode node) -> Types::t

# @rbs node: Prism::HashNode
# @rbs return: Types::t
def infer_hash_element_type: (Prism::HashNode node) -> Types::t
end

class SingletonClassDecl < ModuleOrClass[Prism::SingletonClassNode]
Expand Down
192 changes: 192 additions & 0 deletions test/rbs/inline/writer_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,198 @@ def test_constant__without_decl
RBS
end

def test_constant_decl_with_array_element_type_inference
output = translate(<<~RUBY)
STRINGS = ["a", "b", "c"]

INTEGERS = [1, 2, 3]

NESTED = [[1, 2], [3, 4]]

WORDS_ONE = %w(foo bar baz)

WORDS_TWO = %W(foo bar baz)

WORDS_THREE = %w[foo bar baz]

WORDS_FOUR = %W[foo bar baz]

SYMBOLS_ONE = %i(foo bar baz)

SYMBOLS_TWO = %I(foo bar baz)

SYMBOLS_THREE = %i[foo bar baz]

SYMBOLS_FOUR = %I[foo bar baz]
RUBY

assert_equal <<~RBS, output
STRINGS: ::Array[::String]

INTEGERS: ::Array[::Integer]

NESTED: ::Array[::Array[::Integer]]

WORDS_ONE: ::Array[::String]

WORDS_TWO: ::Array[::String]

WORDS_THREE: ::Array[::String]

WORDS_FOUR: ::Array[::String]

SYMBOLS_ONE: ::Array[::Symbol]

SYMBOLS_TWO: ::Array[::Symbol]

SYMBOLS_THREE: ::Array[::Symbol]

SYMBOLS_FOUR: ::Array[::Symbol]
RBS
end

def test_constant_decl_with_array_element_untyped_inference
output = translate(<<~RUBY)
EMPTY = []

MIXED = [1, "two", :three]

ARRAY_SPLAT = [*other]

SPLAT_RANGE = [*(1..5)]

SPLAT_ARRAY = [*[1, 2, 3]]

ARRAY_WITH_VARS = [x, y, z]

ARRAY_WITH_CONSTS = [FOO, BAR, BAZ]
RUBY

assert_equal <<~RBS, output
EMPTY: ::Array[untyped]

MIXED: ::Array[untyped]

ARRAY_SPLAT: ::Array[untyped]

SPLAT_RANGE: ::Array[untyped]

SPLAT_ARRAY: ::Array[untyped]

ARRAY_WITH_VARS: ::Array[untyped]

ARRAY_WITH_CONSTS: ::Array[untyped]
RBS
end

def test_constant_decl_with_unsupported_array_element_type_inference
output = translate(<<~RUBY)
ARRAY_NEW = Array.new

ARRAY_NEW_SIZED = Array.new(3)

ARRAY_NEW_BLOCK = Array.new(3) { |i| i }

ARRAY_NEW_DEFAULT = Array.new(3, "Ruby")

ARRAY_BRACKET = Array[1, 2, 3]

RANGE_TO_A = (1..5).to_a

RANGE_STR_TO_A = ("a".."e").to_a

KERNEL_ARRAY_INT = Array(1)

KERNEL_ARRAY_ARRAY = Array([1, 2, 3])

KERNEL_ARRAY_RANGE = Array(1..5)
RUBY

assert_equal <<~RBS, output
ARRAY_NEW: untyped

ARRAY_NEW_SIZED: untyped

ARRAY_NEW_BLOCK: untyped

ARRAY_NEW_DEFAULT: untyped

ARRAY_BRACKET: untyped

RANGE_TO_A: untyped

RANGE_STR_TO_A: untyped

KERNEL_ARRAY_INT: untyped

KERNEL_ARRAY_ARRAY: untyped

KERNEL_ARRAY_RANGE: untyped
RBS
end

def test_constant_decl_with_hash_element_type_inference
output = translate(<<~RUBY)
SYMBOL_KEY = { a: 1, b: 2 }

STRING_KEY = { "a" => 1, "b" => 2 }

NESTED = { a: { b: 1 }, c: { d: 2 } }
RUBY

assert_equal <<~RBS, output
SYMBOL_KEY: ::Hash[::Symbol, ::Integer]

STRING_KEY: ::Hash[::String, ::Integer]

NESTED: ::Hash[::Symbol, ::Hash[::Symbol, ::Integer]]
RBS
end

def test_constant_decl_with_hash_element_untyped_inference
output = translate(<<~RUBY)
EMPTY = {}

MIXED = { a: 1, "b" => :two, 3 => "three" }

HASH_SPLAT = {**other}

HASH_WITH_VARS = { x => y }

HASH_WITH_CONSTS = { FOO => BAR }
RUBY

assert_equal <<~RBS, output
EMPTY: ::Hash[untyped, untyped]

MIXED: ::Hash[untyped, untyped]

HASH_SPLAT: ::Hash[untyped, untyped]

HASH_WITH_VARS: ::Hash[untyped, untyped]

HASH_WITH_CONSTS: ::Hash[untyped, untyped]
RBS
end

def test_constant_decl_with_unsupported_hash_element_type_inference
output = translate(<<~RUBY)
HASH_NEW = Hash.new

HASH_NEW_DEFAULT = Hash.new(0)

HASH_BRACKET = Hash["a", 1, "b", 2]
RUBY

assert_equal <<~RBS, output
HASH_NEW: untyped

HASH_NEW_DEFAULT: untyped

HASH_BRACKET: untyped
RBS
end

def test_generic_class_module
output = translate(<<~RUBY)
# @rbs generic T
Expand Down