diff --git a/3rd/EmmyLuaCodeStyle b/3rd/EmmyLuaCodeStyle index 8500f3af1..8c4289b76 160000 --- a/3rd/EmmyLuaCodeStyle +++ b/3rd/EmmyLuaCodeStyle @@ -1 +1 @@ -Subproject commit 8500f3af178f097331d938378648078d023f4c7c +Subproject commit 8c4289b7617ccdb0b247a6171f111f28ac7ae969 diff --git a/changelog.md b/changelog.md index d4ecb8039..0cf094b2b 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,17 @@ ## Unreleased +## 3.17.0 +`2026-01-19` +* `NEW` Support `fun` syntax for inline generic function types in `@field` and `@type` annotations [#1170](https://github.com/LuaLS/lua-language-server/issues/1170) +* `FIX` Generic class inheritance with type arguments now works correctly (e.g., `class Bar: Foo`) [#1929](https://github.com/LuaLS/lua-language-server/issues/1929) +* `FIX` Method return types on generic classes now resolve correctly (e.g., `Box:getValue()` returns `string`) [#1863](https://github.com/LuaLS/lua-language-server/issues/1863) +* `FIX` Self-referential generic classes no longer cause infinite expansion in hover display [#1853](https://github.com/LuaLS/lua-language-server/issues/1853) +* `FIX` Generic type parameters now work in `@overload` annotations [#723](https://github.com/LuaLS/lua-language-server/issues/723) +* `FIX` Methods with `@generic T` and `@param self T` now correctly resolve return type to the receiver's concrete type (e.g., `List:identity()` returns `List`) [#1000](https://github.com/LuaLS/lua-language-server/issues/1000) +* `FIX` Fixed a CPU scheduling bug that prevented the full utilization of high-performance CPUs. +* `FIX` convert all keys to string in `--check` + ## 3.16.4 `2025-12-25` * `FIX` (VSCode) Broken `view document` diff --git a/locale/en-us/meta.lua b/locale/en-us/meta.lua index c0f8edd60..a153780bd 100644 --- a/locale/en-us/meta.lua +++ b/locale/en-us/meta.lua @@ -881,5 +881,5 @@ utf8.len = 'Returns the number of UTF-8 characters in string `s` that start between positions `i` and `j` (both inclusive).' utf8.offset = 'Returns the position (in bytes) where the encoding of the `n`-th character of `s` (counting from position `i`) starts.' -utf8.offset[55] = +utf8.offset['55'] = 'Returns the position of the n-th character of s (counting from byte position i) as two integers: The index (in bytes) where its encoding starts and the index (in bytes) where it ends.' diff --git a/locale/es-419/meta.lua b/locale/es-419/meta.lua index d7ef11f08..29e450ee5 100644 --- a/locale/es-419/meta.lua +++ b/locale/es-419/meta.lua @@ -880,5 +880,5 @@ utf8.len = 'Retorna el número de caracteres en UTF-8 en el string `s` que empiezan entre las posiciones `i` y `j` (ambos inclusive).' utf8.offset = 'Retorna la posición en bytes donde la codificación del caracter `n`-ésimo de `s` empieza, contado a partir de la posición `i`.' -utf8.offset[55] = +utf8.offset['55'] = 'Retorna la posición del carácter número `n` de `s` (contando desde la posición de byte `i`) como dos enteros: el índice (en bytes) donde empieza su codificación y el índice (en bytes) donde termina.' diff --git a/locale/ja-jp/meta.lua b/locale/ja-jp/meta.lua index d541a9131..114f1f270 100644 --- a/locale/ja-jp/meta.lua +++ b/locale/ja-jp/meta.lua @@ -876,5 +876,5 @@ utf8.len = 'バイト位置 `i` から `j`(両方含む)の間に含まれるUTF-8文字の数を返す。' utf8.offset = '文字列 `s` における `n` 番目の文字が始まるバイト位置をを返す。`i` が指定された場合、`i` から数えたバイト位置を返す。' -utf8.offset[55] = +utf8.offset['55'] = '文字列 `s` における `n` 番目の文字の位置を、2つの整数として返します(バイト位置 `i` から数えます):その文字のエンコードが開始するバイトインデックスと終了するバイトインデックス。' diff --git a/locale/pt-br/meta.lua b/locale/pt-br/meta.lua index 2b0810f41..e30b3ce2c 100644 --- a/locale/pt-br/meta.lua +++ b/locale/pt-br/meta.lua @@ -881,5 +881,5 @@ utf8.len = 'Retorna o número de caracteres UTF-8 na string `s` que começa entre as posições `i` e `j` (ambos inclusos).' utf8.offset = 'Retorna a posição (em bytes) onde a codificação do `n`-ésimo caractere de `s` inícia (contando a partir da posição `i`).' -utf8.offset[55] = +utf8.offset['55'] = 'Retorna a posição do n-ésimo caractere de `s` (contando a partir da posição de byte `i`) como dois inteiros: o índice (em bytes) onde sua codificação começa e o índice (em bytes) onde ela termina.' diff --git a/locale/zh-cn/meta.lua b/locale/zh-cn/meta.lua index 86ae136bb..b709fcdd5 100644 --- a/locale/zh-cn/meta.lua +++ b/locale/zh-cn/meta.lua @@ -860,5 +860,5 @@ utf8.len = '返回字符串 `s` 中 从位置 `i` 到 `j` 间 (包括两端) UTF-8 字符的个数。' utf8.offset = '返回编码在 `s` 中的第 `n` 个字符的开始位置(按字节数) (从位置 `i` 处开始统计)。' -utf8.offset[55] = +utf8.offset['55'] = '返回字符串 `s` 中第 `n` 个字符的位置(从字节位置 `i` 开始计数),以两个整数表示:其编码开始的字节索引和结束的字节索引。' diff --git a/locale/zh-tw/meta.lua b/locale/zh-tw/meta.lua index 52169a532..e0c385187 100644 --- a/locale/zh-tw/meta.lua +++ b/locale/zh-tw/meta.lua @@ -864,5 +864,5 @@ utf8.len = '回傳字串 `s` 中 從位置 `i` 到 `j` 間 (包括兩端) UTF-8 字元的個數。' utf8.offset = '回傳編碼在 `s` 中的第 `n` 個字元的開始位置(按位元組數)(從位置 `i` 處開始統計)。' -utf8.offset[55] = +utf8.offset['55'] = '以兩個整數回傳字串 `s` 中第 `n` 個字元的位置(從位元組位置 `i` 開始計數):其編碼開始的位元組索引與結束的位元組索引。' diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua index d97bb76d1..c0a53b355 100644 --- a/script/core/diagnostics/param-type-mismatch.lua +++ b/script/core/diagnostics/param-type-mismatch.lua @@ -5,7 +5,8 @@ local vm = require 'vm' local await = require 'await' ---@param defNode vm.node -local function expandGenerics(defNode) +---@param classGenericMap table? +local function expandGenerics(defNode, classGenericMap) ---@type parser.object[] local generics = {} for dn in defNode:eachObject() do @@ -20,27 +21,78 @@ local function expandGenerics(defNode) end for _, generic in ipairs(generics) do - local limits = generic.generic and generic.generic.extends - if limits then - defNode:merge(vm.compileNode(limits)) + -- First check if this generic is a class generic that can be resolved + local genericName = generic[1] + if classGenericMap and genericName and classGenericMap[genericName] then + defNode:merge(classGenericMap[genericName]) else - local unknownType = vm.declareGlobal('type', 'unknown') - defNode:merge(unknownType) + -- Fall back to constraint or unknown + local limits = generic.generic and generic.generic.extends + if limits then + defNode:merge(vm.compileNode(limits)) + else + local unknownType = vm.declareGlobal('type', 'unknown') + defNode:merge(unknownType) + end + end + end +end + +---@param uri uri +---@param source parser.object +---@return table? +local function getReceiverGenericMap(uri, source) + local callNode = source.node + if not callNode then + return nil + end + -- Only resolve generics for method calls (obj:method()), not static calls (Class.method()) + if callNode.type ~= 'getmethod' then + return nil + end + local receiver = callNode.node + if not receiver then + return nil + end + local receiverNode = vm.compileNode(receiver) + for rn in receiverNode:eachObject() do + if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then + local classGlobal = vm.getGlobal('type', rn.node[1]) + if classGlobal then + return vm.getClassGenericMap(uri, classGlobal, rn.signs) + end end end + return nil end ---@param funcNode vm.node ---@param i integer +---@param classGenericMap table? ---@return vm.node? -local function getDefNode(funcNode, i) +local function getDefNode(funcNode, i, classGenericMap) local defNode = vm.createNode() for src in funcNode:eachObject() do if src.type == 'function' or src.type == 'doc.type.function' then local param = src.args and src.args[i] if param then - defNode:merge(vm.compileNode(param)) + local paramNode = vm.compileNode(param) + -- Check for global type references that match class generic params + if classGenericMap then + local newNode = vm.createNode() + for pn in paramNode:eachObject() do + if pn.type == 'global' and pn.cate == 'type' and classGenericMap[pn.name] then + -- Replace the global type reference with the resolved type + newNode:merge(classGenericMap[pn.name]) + else + newNode:merge(pn) + end + end + defNode:merge(newNode) + else + defNode:merge(paramNode) + end if param[1] == '...' then defNode:addOptional() end @@ -51,7 +103,7 @@ local function getDefNode(funcNode, i) return nil end - expandGenerics(defNode) + expandGenerics(defNode, classGenericMap) return defNode end @@ -87,12 +139,14 @@ return function (uri, callback) end await.delay() local funcNode = vm.compileNode(source.node) + -- Get the class generic map for method calls on generic class instances + local classGenericMap = getReceiverGenericMap(uri, source) for i, arg in ipairs(source.args) do local refNode = vm.compileNode(arg) if not refNode then goto CONTINUE end - local defNode = getDefNode(funcNode, i) + local defNode = getDefNode(funcNode, i, classGenericMap) if not defNode then goto CONTINUE end diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index 1c55f3bf3..b648b9ab0 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -3,6 +3,100 @@ local guide = require 'parser.guide' local lang = require 'language' local vm = require 'vm' +--- Check if name is a generic parameter from a class context +---@param source parser.object The doc.type.name source +---@param name string The type name to check +---@param uri uri The file URI +---@return boolean +local function isClassGenericParam(source, name, uri) + -- Find containing doc node + local doc = guide.getParentTypes(source, { + ['doc.return'] = true, + ['doc.param'] = true, + ['doc.type'] = true, + ['doc.field'] = true, + ['doc.overload'] = true, + ['doc.vararg'] = true, + }) + if not doc then + return false + end + + -- Walk up to find a doc node with bindGroup (intermediate doc.type nodes don't have it) + while doc and not doc.bindGroup do + doc = doc.parent + end + if not doc then + return false + end + + -- Check bindGroup for class/alias with matching generic sign + local bindGroup = doc.bindGroup + if bindGroup then + for _, other in ipairs(bindGroup) do + if (other.type == 'doc.class' or other.type == 'doc.alias') and other.signs then + for _, sign in ipairs(other.signs) do + if sign[1] == name then + return true + end + end + end + end + end + + -- Check direct class reference (for doc.field, doc.overload, doc.operator) + if doc.class and doc.class.signs then + for _, sign in ipairs(doc.class.signs) do + if sign[1] == name then + return true + end + end + end + + -- Check if bound to a method on a generic class + -- Find the function from any doc in the bindGroup + local func = nil + if bindGroup then + for _, other in ipairs(bindGroup) do + local bindSource = other.bindSource + if bindSource then + if bindSource.type == 'function' then + -- doc.return binds directly to function + func = bindSource + break + else + -- doc.param binds to local param, find containing function + func = guide.getParentFunction(bindSource) + if func then + break + end + end + end + end + end + + -- If we found a function, check if it's a method on a generic class + if func and func.parent then + local parent = func.parent + if parent.type == 'setmethod' or parent.type == 'setfield' or parent.type == 'setindex' then + local classGlobal = vm.getDefinedClass(uri, parent.node) + if classGlobal then + for _, set in ipairs(classGlobal:getSets(uri)) do + if set.type == 'doc.class' and set.signs then + for _, sign in ipairs(set.signs) do + if sign[1] == name then + return true + end + end + end + end + end + end + end + + return false +end + return function (uri, callback) local state = files.getState(uri) if not state then @@ -25,6 +119,9 @@ return function (uri, callback) if name == '...' or name == '_' or name == 'self' then return end + if isClassGenericParam(source, name, uri) then + return + end if #vm.getDocSets(uri, name) > 0 then return end diff --git a/script/locale-loader.lua b/script/locale-loader.lua index 018da6806..c98529830 100644 --- a/script/locale-loader.lua +++ b/script/locale-loader.lua @@ -2,6 +2,7 @@ local function mergeKey(key, k) if not key then return k end + k = tostring(k) if k:sub(1, 1):match '%w' then return key .. '.' .. k else diff --git a/script/parser/compile.lua b/script/parser/compile.lua index d6a8722cc..aa82b25a4 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -3455,7 +3455,6 @@ local function parseGlobal() if attrs then glob.attrs = attrs attrs.parent = glob - glob.start = globalPos for i = 1, #attrs do if attrs[i][1] == 'close' then pushError { @@ -3465,8 +3464,6 @@ local function parseGlobal() } end end - else - glob.start = name.start end -- attributes after name diff --git a/script/parser/guide.lua b/script/parser/guide.lua index c407eca67..76a2d4fc0 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -177,7 +177,7 @@ local childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.function'] = {'#args', '#returns', 'comment'}, + ['doc.type.function'] = {'#args', '#returns', '#signs', 'comment'}, ['doc.type.table'] = {'#fields', 'comment'}, ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'name', 'extends'}, diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ad22af1d7..f31c54622 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -523,6 +523,8 @@ local function parseTypeUnitFunction(parent) args = {}, returns = {}, } + -- Parse optional generic params: fun(...) + typeUnit.signs = parseSigns(typeUnit) if not nextSymbolOrError('(') then return nil end @@ -617,6 +619,51 @@ local function parseTypeUnitFunction(parent) end end typeUnit.finish = getFinish() + -- Bind local generics from fun to type names within this function + if typeUnit.signs then + local generics = {} + for _, sign in ipairs(typeUnit.signs) do + generics[sign[1]] = sign + end + local function bindTypeNames(obj) + if not obj then return end + if obj.type == 'doc.type.name' and generics[obj[1]] then + obj.type = 'doc.generic.name' + obj.generic = generics[obj[1]] + elseif obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + bindTypeNames(t) + end + elseif obj.type == 'doc.type.array' then + bindTypeNames(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + bindTypeNames(field.name) + bindTypeNames(field.extends) + end + elseif obj.type == 'doc.type.sign' then + bindTypeNames(obj.node) + if obj.signs then + for _, s in ipairs(obj.signs) do + bindTypeNames(s) + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(obj.returns) do + bindTypeNames(ret) + end + end + end + for _, arg in ipairs(typeUnit.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(typeUnit.returns) do + bindTypeNames(ret) + end + end return typeUnit end @@ -1030,6 +1077,12 @@ local docSwitch = util.switch() } return result end + if extend.type == 'doc.extends.name' then + local signResult = parseTypeUnitSign(result, extend) + if signResult then + extend = signResult + end + end result.extends[#result.extends+1] = extend result.finish = getFinish() if not checkToken('symbol', ',', 1) then @@ -1850,7 +1903,9 @@ local function bindGeneric(binded) or doc.type == 'doc.return' or doc.type == 'doc.type' or doc.type == 'doc.class' - or doc.type == 'doc.alias' then + or doc.type == 'doc.alias' + or doc.type == 'doc.field' + or doc.type == 'doc.overload' then guide.eachSourceType(doc, 'doc.type.name', function (src) local name = src[1] if generics[name] then diff --git a/script/pub/pub.lua b/script/pub/pub.lua index ae57c0553..a5fd2058d 100644 --- a/script/pub/pub.lua +++ b/script/pub/pub.lua @@ -235,16 +235,12 @@ function m.recieve(block) selector:wait(-1) -- 遍历公共组 for _, brave in ipairs(m.publicBraves) do - if m.reciveFromPad(brave) then - return - end + m.reciveFromPad(brave) end -- 遍历所有专用组 for _, braveList in pairs(m.privateBraves) do for _, brave in ipairs(braveList) do - if m.reciveFromPad(brave) then - return - end + m.reciveFromPad(brave) end end else diff --git a/script/service/service.lua b/script/service/service.lua index bd087433f..116bc3bfe 100644 --- a/script/service/service.lua +++ b/script/service/service.lua @@ -168,13 +168,7 @@ function m.eventLoop() end end - local lastNetUpdateTime = 0 local function doSomething() - local now = time.monotonic() - if now - lastNetUpdateTime >= 100 then - net.update() - lastNetUpdateTime = now - end timer.update() pub.step(false) if await.step() then @@ -185,17 +179,24 @@ function m.eventLoop() end local function sleep() - idle() - for _ = 1, 10 do - net.update(100) - if doSomething() then - return + while true do + idle() + for _ = 1, 10 do + net.update(100) + if doSomething() then + return + end end + pub.step(true) end - pub.step(true) end while true do + net.update() + local clock = os.clock() + while os.clock() - clock < 0.1 do + doSomething() + end if doSomething() then goto CONTINUE end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index af7b7cc69..5267a037b 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -234,6 +234,108 @@ local function searchLiteralFieldFromTable(source, key, callback) end end +---@param obj parser.object +---@return boolean +local function containsGenericName(obj) + if not obj then + return false + end + if obj.type == 'doc.generic.name' then + return true + end + if obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + if containsGenericName(t) then + return true + end + end + elseif obj.type == 'doc.type.array' then + return containsGenericName(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + if containsGenericName(field.name) or containsGenericName(field.extends) then + return true + end + end + elseif obj.type == 'doc.type.sign' then + if obj.signs then + for _, s in ipairs(obj.signs) do + if containsGenericName(s) then + return true + end + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args or {}) do + if containsGenericName(arg.extends) then + return true + end + end + for _, ret in ipairs(obj.returns or {}) do + if containsGenericName(ret) then + return true + end + end + end + return false +end + +---Builds a map from generic parameter names to their concrete types +---@param uri uri +---@param classGlobal vm.global +---@param signs parser.object[] +---@return table? +function vm.getClassGenericMap(uri, classGlobal, signs) + for _, set in ipairs(classGlobal:getSets(uri)) do + if set.type == 'doc.class' and set.signs then + local resolved = {} + for i, signName in ipairs(set.signs) do + local signType = signs[i] + if signType and signName[1] then + resolved[signName[1]] = vm.compileNode(signType) + end + end + if next(resolved) then + return resolved + end + break + end + end + return nil +end + +---@param uri uri +---@param classGlobal vm.global +---@param field parser.object | vm.generic +---@param signs parser.object[] +---@return parser.object? +local function resolveGenericField(uri, classGlobal, field, signs) + if field.type ~= 'doc.field' or not field.extends then + return nil + end + if not containsGenericName(field.extends) then + return nil + end + local resolved = vm.getClassGenericMap(uri, classGlobal, signs) + if not resolved then + return nil + end + local newExtends = vm.cloneObject(field.extends, resolved) + if not newExtends then + return nil + end + return { + type = field.type, + start = field.start, + finish = field.finish, + parent = field.parent, + field = field.field, + extends = newExtends, + visible = field.visible, + optional = field.optional, + } +end + local searchFieldSwitch = util.switch() : case 'table' : call(function (_suri, source, key, pushResult) @@ -357,7 +459,16 @@ local searchFieldSwitch = util.switch() if not globalVar then return end - vm.getClassFields(suri, globalVar, key, pushResult) + vm.getClassFields(suri, globalVar, key, function (field, isMark) + if source.signs then + local newField = resolveGenericField(suri, globalVar, field, source.signs) + if newField then + pushResult(newField, isMark) + return + end + end + pushResult(field, isMark) + end) end) : case 'global' : call(function (suri, node, key, pushResult) @@ -565,7 +676,6 @@ function vm.getClassFields(suri, object, key, pushResult) for _, set in ipairs(sets) do if set.type == 'doc.class' then - -- look into extends(if field not found) if not searchedFields[key] and set.extends then for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then @@ -573,6 +683,14 @@ function vm.getClassFields(suri, object, key, pushResult) if extendType then searchClass(extendType, searchedFields) end + elseif extend.type == 'doc.type.sign' then + searchFieldSwitch(extend.type, suri, extend, key, function (field, isMark) + local fieldKey = guide.getKeyName(field) + if fieldKey and not searchedFields[fieldKey] then + hasFounded[fieldKey] = true + pushResult(field, isMark) + end + end) end end end @@ -1484,16 +1602,162 @@ local function bindReturnOfFunction(source, mfunc, index, args) if not returnObject then return end + + local resolveArgs = args + if source.func and source.func.type == 'getmethod' then + local receiver = source.func.node + if receiver then + resolveArgs = { receiver } + if args then + for i = 2, #args do + resolveArgs[#resolveArgs + 1] = args[i] + end + end + end + end + local returnNode = vm.compileNode(returnObject) + + local selfGenericResolved = nil + if source.func and source.func.type == 'getmethod' and mfunc.type == 'function' and mfunc.bindDocs then + local receiver = source.func.node + if receiver then + local receiverNode = vm.compileNode(receiver) + local selfGenericName = nil + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.param' and doc.param and doc.param[1] == 'self' then + if doc.extends then + for _, typeUnit in ipairs(doc.extends.types or {}) do + if typeUnit.type == 'doc.generic.name' then + selfGenericName = typeUnit[1] + break + end + end + end + break + end + end + if selfGenericName then + local filteredNode = vm.createNode() + for item in receiverNode:eachObject() do + if item.type == 'doc.type.sign' + or (item.type == 'global' and item.cate == 'type') + or item.type == 'doc.type.table' + or item.type == 'doc.type.array' then + filteredNode:merge(item) + end + end + if not filteredNode:isEmpty() then + selfGenericResolved = { [selfGenericName] = filteredNode } + else + selfGenericResolved = { [selfGenericName] = receiverNode } + end + end + end + end + for rnode in returnNode:eachObject() do if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(source), args) + if selfGenericResolved and rnode.sign then + local resolved = rnode.sign:resolve(guide.getUri(source), resolveArgs) or {} + for k, v in pairs(selfGenericResolved) do + resolved[k] = v + end + local protoNode = vm.compileNode(rnode.proto) + local result = vm.createNode() + for nd in protoNode:eachObject() do + if nd.type == 'global' or nd.type == 'variable' then + result:merge(nd) + else + local clonedObject = vm.cloneObject(nd, resolved) + if clonedObject then + result:merge(vm.compileNode(clonedObject)) + end + end + end + if protoNode:isOptional() then + result:addOptional() + end + returnNode = result + else + returnNode = rnode:resolve(guide.getUri(source), resolveArgs) + end break end end + + if mfunc.type == 'function' then + local hasUnresolvedGeneric = false + for rnode in returnNode:eachObject() do + if vm.isGenericUnsolved(rnode) then + hasUnresolvedGeneric = true + break + end + end + if hasUnresolvedGeneric then + local sign = vm.getSign(mfunc) + if sign and resolveArgs and #resolveArgs > 0 then + local resolved = sign:resolve(guide.getUri(source), resolveArgs) + if resolved and next(resolved) then + local newReturnNode = vm.createNode() + for rnode in returnNode:eachObject() do + local cloned = vm.cloneObject(rnode, resolved) + if cloned then + newReturnNode:merge(vm.compileNode(cloned)) + else + newReturnNode:merge(rnode) + end + end + returnNode = newReturnNode + end + end + end + end + + local call = source.parent + if not selfGenericResolved and call and call.type == 'call' then + local callNode = call.node + if callNode and (callNode.type == 'getmethod' or callNode.type == 'getfield') then + local receiver = callNode.node + if receiver then + local receiverNode = vm.compileNode(receiver) + for rn in receiverNode:eachObject() do + if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then + local classGlobal = vm.getGlobal('type', rn.node[1]) + if classGlobal then + local genericMap = vm.getClassGenericMap(guide.getUri(source), classGlobal, rn.signs) + if genericMap and mfunc.bindDocs then + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + if rtn.returnIndex == index then + local newRtn = vm.cloneObject(rtn, genericMap) + if newRtn then + returnNode = vm.compileNode(newRtn) + for rnode in returnNode:eachObject() do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(source), args) + break + end + end + end + break + end + end + break + end + end + end + break + end + end + end + end + end + end + if returnNode then for rnode in returnNode:eachObject() do - -- TODO: narrow type if rnode.type ~= 'doc.generic.name' then vm.setNode(source, rnode) end @@ -2068,7 +2332,11 @@ local compilerSwitch = util.switch() end) : case 'doc.generic.name' : call(function (source) - vm.setNode(source, source) + if source._resolved then + vm.setNode(source, source._resolved) + else + vm.setNode(source, source) + end end) : case 'doc.type.sign' : call(function (source) @@ -2088,6 +2356,9 @@ local compilerSwitch = util.switch() if ext.type == 'doc.type.table' then if vm.getGeneric(ext) then local resolved = vm.getGeneric(ext):resolve(uri, source.signs) + for obj in resolved:eachObject() do + obj.hideView = true + end vm.setNode(source, resolved) end end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index f1eaaf99d..d2c75eafa 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -12,7 +12,7 @@ local mt = {} mt.__index = mt mt.type = 'generic' ----@param source vm.object? +---@param source table? ---@param resolved? table ---@return vm.object? local function cloneObject(source, resolved) @@ -34,6 +34,21 @@ local function cloneObject(source, resolved) end return newName end + if source.type == 'doc.type.name' then + local key = source[1] + if resolved[key] then + local newName = { + type = 'doc.generic.name', + start = source.start, + finish = source.finish, + parent = source.parent, + [1] = source[1], + } + vm.setNode(newName, resolved[key], true) + newName._resolved = resolved[key] + return newName + end + end if source.type == 'doc.type' then local newType = { type = source.type, @@ -106,13 +121,43 @@ local function cloneObject(source, resolved) newDocFunc.args[i] = newObj end for i, ret in ipairs(source.returns) do - local newObj = cloneObject(ret, resolved) + local newObj = cloneObject(ret, resolved) newObj.parent = newDocFunc newObj.optional = ret.optional - newDocFunc.returns[i] = cloneObject(ret, resolved) + newDocFunc.returns[i] = newObj end return newDocFunc end + if source.type == 'doc.type.sign' and source.signs then + local needsClone = false + for _, sign in ipairs(source.signs) do + if sign.type == 'doc.type' then + for _, tp in ipairs(sign.types) do + if tp.type == 'doc.type.name' and resolved[tp[1]] then + needsClone = true + break + end + end + elseif sign.type == 'doc.type.name' and resolved[sign[1]] then + needsClone = true + end + if needsClone then break end + end + if needsClone then + local newSign = { + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, + node = source.node, + signs = {}, + } + for i, sign in ipairs(source.signs) do + newSign.signs[i] = cloneObject(sign, resolved) + end + return newSign + end + end return source end @@ -151,6 +196,14 @@ function vm.getGenericResolved(source) return source._resolved end +---@param source table +function vm.isGenericUnsolved(source) + if source.type == 'doc.generic.name' and not source._resolved then + return true + end + return false +end + ---@param source parser.object ---@param generic vm.generic function vm.setGeneric(source, generic) @@ -173,3 +226,10 @@ function vm.createGeneric(proto, sign) }, mt) return generic end + +---@param source table? +---@param resolved? table +---@return vm.object? +function vm.cloneObject(source, resolved) + return cloneObject(source, resolved) +end diff --git a/test/diagnostics/undefined-doc-name.lua b/test/diagnostics/undefined-doc-name.lua index 9a55108ac..a28a654a6 100644 --- a/test/diagnostics/undefined-doc-name.lua +++ b/test/diagnostics/undefined-doc-name.lua @@ -17,3 +17,51 @@ TEST [[ TEST [[ ---@alias B ]] + +-- Generic class methods should not warn about class generic params +TEST [[ +---@class Container +local Container = {} + +---@return T[] +function Container:getAll() + return {} +end +]] + +-- Inline class fields with generics should not warn +TEST [[ +---@class Box +---@field value T +]] + +-- Multiple generic params should all be recognized +TEST [[ +---@class Map +local Map = {} + +---@param key K +---@return V +function Map:get(key) +end +]] + +-- Variable name different from class name +TEST [[ +---@class Pool +local M = {} + +---@param item T +function M:push(item) end +]] + +-- Undefined types SHOULD still warn (control case) +TEST [[ +---@class Container +local Container = {} + +---@return +function Container:getBad() + return {} +end +]] diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 7f6d854c7..c54c82ee1 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4909,3 +4909,263 @@ function f(...args) print() end ]] + +TEST 'integer' [[ +---@class Foo +---@field a T + +---@class Bar: Foo + +---@type Bar +local x +local = x.a +]] + +TEST 'string' [[ +---@class GenericBase +---@field value T + +---@class StringHolder: GenericBase + +---@type StringHolder +local holder +local = holder.value +]] + +TEST 'boolean' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.val +]] + +TEST 'string' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.key +]] + +TEST 'string' [[ +---@class Box +---@field value T +local Box = {} + +---@return T +function Box:getValue() + return self.value +end + +---@type Box +local b + +local = b:getValue() +]] + +TEST 'integer' [[ +---@class Wrapper +---@field item V +local Wrapper = {} + +---@return V +function Wrapper:unwrap() + return self.item +end + +---@type Wrapper +local w + +local = w:unwrap() +]] + +-- Issue #1856: Generic class display format +-- Current behavior shows list<>|{...} - the <> indicates an unresolved generic +-- The resolved table type is also shown +TEST 'list<>|{ [integer]: string }' [[ +---@class list: {[integer]:T} + +---@generic T +---@param class `T` +---@return list +local function new_list(class) + return {} +end + +local = new_list('string') +]] + +-- Issue #1853: Recursive expansion on hover of generic type +-- Self-referential generic classes should not expand infinitely +TEST 'store' [[ +---@class store: {set:fun(self:store, key:integer, value:T), get:fun(self:store, key:integer):T} + +local ---@type store +]] + +-- Test composite return types with generics (T[] should resolve to string[]) +TEST 'string[]' [[ +---@class Container +local Container = {} + +---@return T[] +function Container:getAll() + return {} +end + +---@type Container +local c + +local = c:getAll() +]] + +-- Test nested generic return types (Wrapper should resolve to Wrapper) +TEST 'Wrapper' [[ +---@class Wrapper +---@field value V + +---@class Factory +local Factory = {} + +---@return Wrapper +function Factory:wrap() + return {} +end + +---@type Factory +local f + +local = f:wrap() +]] + +-- Issue #723: Generics in @overload +-- @generic should work with @overload annotations +TEST 'string' [[ +---@generic T +---@param x T +---@return T +---@overload fun(x: T): T +local function identity(x) + return x +end + +local = identity("hello") +]] + +-- Issue #723: Multiple generics in @overload +TEST 'integer' [[ +---@generic K, V +---@param k K +---@param v V +---@return V +---@overload fun(k: K, v: V): V +local function getValue(k, v) + return v +end + +local = getValue("key", 42) +]] + +-- Issue #1170: Generics in function type format (fun) +TEST 'string' [[ +---@type fun(x: T): T +local identity + +local = identity("hello") +]] + +-- Issue #1170: Multiple generics in function type +TEST 'boolean' [[ +---@type fun(k: K, v: V): V +local getSecond + +local = getSecond("key", true) +]] + +-- Issue #1170: Generic function in @field +TEST 'integer' [[ +---@class Mapper +---@field transform fun(input: T, fn: fun(x: T): U): U + +---@type Mapper +local m + +local = m.transform("hello", function(x) return #x end) +]] + +-- Issue #1532: Promise-like method chaining +-- Method returning self-type should preserve generic param through chain +TEST 'string' [[ +---@class Promise +---@field value T +local Promise = {} + +---@return Promise +function Promise:next(fn) + return self +end + +---@return T +function Promise:await() + return self.value +end + +---@type Promise +local p + +local = p:next(function() end):await() +]] + +-- Issue #1532: Multiple chained methods +TEST 'number' [[ +---@class Chain +local Chain = {} + +---@return Chain +function Chain:map(fn) + return self +end + +---@return Chain +function Chain:filter(fn) + return self +end + +---@return V +function Chain:first() + return nil +end + +---@type Chain +local c + +local = c:map(function() end):filter(function() end):first() +]] + +-- Issue #1000: Generic self parameter should resolve to concrete type +-- When @generic T and @param self T, calling on List should return List +TEST 'List' [[ +---@class List +local List = {} + +---@generic T +---@param self T +---@return T +function List:identity() + return self +end + +---@type List +local mylist + +local = mylist:identity() +]]