diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 503db17d..f64ab1d3 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -6,18 +6,18 @@ jobs: name: coverage runs-on: ubuntu-latest container: - image: xd009642/tarpaulin + image: xd009642/tarpaulin:develop-nightly options: --security-opt seccomp=unconfined steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@main - - name: Generate code coverage + - name: Generate coverage report run: | - cargo tarpaulin --verbose --features lua54,vendored,async,send,serialize,macros --out xml --exclude-files benches --exclude-files build --exclude-files mlua_derive --exclude-files src/ffi --exclude-files tests + cargo +nightly tarpaulin --verbose --out xml --tests --exclude-files benches/* --exclude-files mlua-sys/src/*/* - - name: Upload to codecov.io - uses: codecov/codecov-action@v1 + - name: Upload report to codecov.io + uses: codecov/codecov-action@v4 with: token: ${{secrets.CODECOV_TOKEN}} fail_ci_if_error: false diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..662a04a3 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,68 @@ +name: Documentation (main) + +on: + push: + branches: [main] + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow only one concurrent deployment +concurrency: + group: pages + cancel-in-progress: true + +jobs: + build: + name: Build Documentation + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@nightly + - uses: Swatinem/rust-cache@v2 + + - name: Build documentation + env: + RUSTDOCFLAGS: "--cfg docsrs" + run: | + cargo +nightly doc --no-deps \ + --features "lua55,vendored,async,send,serde,macros,anyhow,userdata-wrappers" + + - name: Create index redirect + run: | + echo ' + + + + Redirecting to mlua documentation + + + + +

Redirecting to mlua documentation...

+ + ' > target/doc/index.html + + - name: Setup Pages + uses: actions/configure-pages@v5 + + - name: Upload artifact + uses: actions/upload-pages-artifact@v4 + with: + path: target/doc + + deploy: + name: Deploy to GitHub Pages + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 10133f04..b67adccb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,98 +7,79 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] rust: [stable] - lua: [lua54, lua53, lua52, lua51, luajit, luau] + lua: [lua55, lua54, lua53, lua52, lua51, luajit, luau, luau-jit, luau-vector4] include: - - os: ubuntu-20.04 - target: x86_64-unknown-linux-gnu - - os: macos-latest - target: x86_64-apple-darwin - - os: windows-latest - target: x86_64-pc-windows-msvc + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + - os: macos-latest + target: aarch64-apple-darwin + - os: windows-latest + target: x86_64-pc-windows-msvc steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.rust }} - target: ${{ matrix.target }} - override: true - - uses: Swatinem/rust-cache@v1 - - name: Build ${{ matrix.lua }} vendored - run: | - cargo build --features "${{ matrix.lua }},vendored" - cargo build --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" - shell: bash - - name: Build ${{ matrix.lua }} pkg-config - if: ${{ matrix.os == 'ubuntu-20.04' && matrix.lua != 'lua54' }} - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends liblua5.3-dev liblua5.2-dev liblua5.1-0-dev libluajit-5.1-dev - cargo build --features "${{ matrix.lua }}" - - build_aarch64_cross_macos: - name: Cross-compile to aarch64-apple-darwin - runs-on: macos-latest - needs: build - strategy: - matrix: - lua: [lua54, lua53, lua52, lua51, luajit] - steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - target: aarch64-apple-darwin - override: true - - name: Cross-compile - run: cargo build --target aarch64-apple-darwin --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Build ${{ matrix.lua }} vendored + run: | + cargo build --features "${{ matrix.lua }},vendored" + cargo build --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers" + cargo build --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers,send" + shell: bash + - name: Build ${{ matrix.lua }} pkg-config + if: ${{ matrix.os == 'ubuntu-latest' && matrix.lua != 'lua55' }} + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends liblua5.4-dev liblua5.3-dev liblua5.2-dev liblua5.1-0-dev libluajit-5.1-dev + cargo build --features "${{ matrix.lua }}" build_aarch64_cross_ubuntu: name: Cross-compile to aarch64-unknown-linux-gnu - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest needs: build strategy: matrix: - lua: [lua54, lua53, lua52, lua51, luajit] + lua: [lua55, lua54, lua53, lua52, lua51, luajit] steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - target: aarch64-unknown-linux-gnu - override: true - - name: Install ARM compiler toolchain - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu libc6-dev-arm64-cross - shell: bash - - name: Cross-compile - run: cargo build --target aarch64-unknown-linux-gnu --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" - shell: bash + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + target: aarch64-unknown-linux-gnu + - name: Install ARM compiler toolchain + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu libc6-dev-arm64-cross + shell: bash + - name: Cross-compile + run: cargo build --target aarch64-unknown-linux-gnu --features "${{ matrix.lua }},vendored,async,send,serde,macros,anyhow,userdata-wrappers" + shell: bash build_armv7_cross_ubuntu: name: Cross-compile to armv7-unknown-linux-gnueabihf - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest needs: build strategy: matrix: - lua: [lua54, lua53, lua52, lua51] + lua: [lua55, lua54, lua53, lua52, lua51] steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - target: armv7-unknown-linux-gnueabihf - override: true - - name: Install ARM compiler toolchain - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends gcc-arm-linux-gnueabihf libc-dev-armhf-cross - shell: bash - - name: Cross-compile - run: cargo build --target armv7-unknown-linux-gnueabihf --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" - shell: bash + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + target: armv7-unknown-linux-gnueabihf + - name: Install ARM compiler toolchain + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends gcc-arm-linux-gnueabihf libc-dev-armhf-cross + shell: bash + - name: Cross-compile + run: cargo build --target armv7-unknown-linux-gnueabihf --features "${{ matrix.lua }},vendored,async,send,serde,macros,anyhow,userdata-wrappers" + shell: bash test: name: Test @@ -106,35 +87,35 @@ jobs: needs: build strategy: matrix: - os: [ubuntu-20.04, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] rust: [stable, nightly] - lua: [lua54, lua53, lua52, lua51, luajit, luajit52, luau] + lua: [lua55, lua54, lua53, lua52, lua51, luajit, luajit52, luau, luau-jit, luau-vector4] include: - - os: ubuntu-20.04 - target: x86_64-unknown-linux-gnu - - os: macos-latest - target: x86_64-apple-darwin - - os: windows-latest - target: x86_64-pc-windows-msvc + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + - os: macos-latest + target: aarch64-apple-darwin + - os: windows-latest + target: x86_64-pc-windows-msvc steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.rust }} - target: ${{ matrix.target }} - override: true - - uses: Swatinem/rust-cache@v1 - - name: Run ${{ matrix.lua }} tests - run: | - cargo test --features "${{ matrix.lua }},vendored" - cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" - shell: bash - - name: Run compile tests (macos lua54) - if: ${{ matrix.os == 'macos-latest' && matrix.lua == 'lua54' }} - run: | - TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored" -- --ignored - TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" -- --ignored - shell: bash + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Run ${{ matrix.lua }} tests + run: | + cargo test --features "${{ matrix.lua }},vendored" + cargo test --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers" + cargo test --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers,send" + shell: bash + - name: Run compile tests (macos lua55) + if: ${{ matrix.os == 'macos-latest' && matrix.lua == 'lua55' }} + run: | + TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored" --tests -- --ignored + TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored,async,send,serde,macros" --tests -- --ignored + shell: bash test_with_sanitizer: name: Test with address sanitizer @@ -142,25 +123,52 @@ jobs: needs: build strategy: matrix: - os: [ubuntu-20.04] + os: [ubuntu-latest] rust: [nightly] - lua: [lua54, lua53, lua52, lua51, luajit, luau] + lua: [lua55, lua54, lua53, lua52, lua51, luajit, luau, luau-jit, luau-vector4] include: - - os: ubuntu-20.04 - target: x86_64-unknown-linux-gnu + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.rust }} - target: ${{ matrix.target }} - override: true - - uses: Swatinem/rust-cache@v1 - - name: Run ${{ matrix.lua }} tests with address sanitizer - run: | - RUSTFLAGS="-Z sanitizer=address" \ - cargo test --tests --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" --target x86_64-unknown-linux-gnu -- --skip test_too_many_recursions - shell: bash + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Run ${{ matrix.lua }} tests with address sanitizer + run: | + cargo test --tests --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow" --target x86_64-unknown-linux-gnu -- --skip test_too_many_recursions + cargo test --tests --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers,send" --target x86_64-unknown-linux-gnu -- --skip test_too_many_recursions + shell: bash + env: + RUSTFLAGS: -Z sanitizer=address + + test_with_memory_limit: + name: Test with memory limit + runs-on: ${{ matrix.os }} + needs: build + strategy: + matrix: + os: [ubuntu-latest] + rust: [nightly] + lua: [lua55, lua54, lua53, lua52, lua51, luajit, luau, luau-jit, luau-vector4] + include: + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + steps: + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Run ${{ matrix.lua }} tests with forced memory limit + run: | + cargo test --tests --features "${{ matrix.lua }},vendored,async,send,serde,macros,anyhow,userdata-wrappers" + shell: bash + env: + RUSTFLAGS: --cfg=force_memory_limit test_modules: name: Test modules @@ -168,27 +176,26 @@ jobs: needs: build strategy: matrix: - os: [ubuntu-20.04, macos-latest] + os: [ubuntu-latest, macos-latest] rust: [stable] - lua: [lua54, lua53, lua52, lua51, luajit] + lua: [lua55, lua54, lua53, lua52, lua51, luajit] include: - - os: ubuntu-20.04 - target: x86_64-unknown-linux-gnu - - os: macos-latest - target: x86_64-apple-darwin + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + - os: macos-latest + target: aarch64-apple-darwin steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.rust }} - target: ${{ matrix.target }} - override: true - - uses: Swatinem/rust-cache@v1 - - name: Run ${{ matrix.lua }} module tests - run: | - (cd tests/module && cargo build --release --features "${{ matrix.lua }}") - (cd tests/module/loader && cargo test --release --features "${{ matrix.lua }},vendored") - shell: bash + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Run ${{ matrix.lua }} module tests + run: | + (cd tests/module && cargo build --release --features "${{ matrix.lua }}") + (cd tests/module/loader && cargo test --release --features "${{ matrix.lua }},vendored") + shell: bash test_modules_windows: name: Test modules on Windows @@ -201,42 +208,96 @@ jobs: run: shell: msys2 {0} steps: - - uses: msys2/setup-msys2@v2 - - uses: actions/checkout@v2 - - name: Install Rust & Lua - run: | - pacman -S --noconfirm mingw-w64-x86_64-rust mingw-w64-x86_64-lua mingw-w64-x86_64-luajit mingw-w64-x86_64-pkg-config - - name: Run ${{ matrix.lua }} module tests - run: | - (cd tests/module && cargo build --release --features "${{ matrix.lua }}") - (cd tests/module/loader && cargo test --release --features "${{ matrix.lua }}") + - uses: msys2/setup-msys2@v2 + - uses: actions/checkout@main + - name: Install Rust & Lua + run: | + pacman -S --noconfirm mingw-w64-x86_64-rust mingw-w64-x86_64-lua mingw-w64-x86_64-luajit mingw-w64-x86_64-pkg-config + - name: Run ${{ matrix.lua }} module tests + run: | + (cd tests/module && cargo build --release --features "${{ matrix.lua }}") + (cd tests/module/loader && cargo test --release --features "${{ matrix.lua }}") + + test_wasm32_emscripten: + name: Test on wasm32-unknown-emscripten + runs-on: ubuntu-latest + needs: build + strategy: + matrix: + lua: [lua55, lua54, lua53, lua52, lua51, luau] + steps: + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + target: wasm32-unknown-emscripten + - name: Install Emscripten + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends emscripten + - name: Run ${{ matrix.lua }} tests + run: | + cargo test --tests --features "${{ matrix.lua }},vendored" + cargo test --tests --features "${{ matrix.lua }},vendored,async,serde,macros,anyhow,userdata-wrappers" + + test_wasm32_wasip2: + name: Test on wasm32-wasip2 + runs-on: ubuntu-latest + needs: build + strategy: + matrix: + lua: [lua55, lua54, lua53, lua52, lua51] + steps: + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: nightly-2025-10-02 + target: wasm32-wasip2 + - name: Install wasi-sdk/Wasmtime + working-directory: ${{ runner.tool_cache }} + run: | + wasi_sdk=29 + wasmtime=v40.0.1 + + curl -LO https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-$wasi_sdk/wasi-sdk-$wasi_sdk.0-x86_64-linux.tar.gz + tar xf wasi-sdk-$wasi_sdk.0-x86_64-linux.tar.gz + WASI_SDK_PATH=`pwd`/wasi-sdk-$wasi_sdk.0-x86_64-linux + echo "WASI_SDK_PATH=$WASI_SDK_PATH" >> $GITHUB_ENV + echo "CC_wasm32_wasip2=$WASI_SDK_PATH/bin/clang" >> $GITHUB_ENV + echo "CARGO_TARGET_WASM32_WASIP2_LINKER=$WASI_SDK_PATH/bin/clang" >> $GITHUB_ENV + echo "CARGO_TARGET_WASM32_WASIP2_RUSTFLAGS=-Clink-arg=-Wl,--export=cabi_realloc" >> $GITHUB_ENV + + curl -LO https://github.com/bytecodealliance/wasmtime/releases/download/$wasmtime/wasmtime-$wasmtime-x86_64-linux.tar.xz + tar xf wasmtime-$wasmtime-x86_64-linux.tar.xz + echo "CARGO_TARGET_WASM32_WASIP2_RUNNER=`pwd`/wasmtime-$wasmtime-x86_64-linux/wasmtime -W exceptions" >> $GITHUB_ENV + - name: Run ${{ matrix.lua }} tests + run: | + cargo test --target wasm32-wasip2 --tests --features "${{ matrix.lua }},vendored" + cargo test --target wasm32-wasip2 --tests --features "${{ matrix.lua }},vendored,serde,macros,anyhow,userdata-wrappers" rustfmt: name: Rustfmt - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - components: rustfmt - override: true - - run: cargo fmt -- --check + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@nightly + with: + components: rustfmt + - run: cargo fmt -- --check clippy: - name: Clippy check - runs-on: ubuntu-20.04 + name: Clippy + runs-on: ubuntu-latest strategy: matrix: - lua: [lua54, lua53, lua52, lua51, luajit, luau] + lua: [lua55, lua54, lua53, lua52, lua51, luajit, luau, luau-jit, luau-vector4] steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 + - uses: actions/checkout@main + - uses: dtolnay/rust-toolchain@stable with: - toolchain: nightly - components: clippy - override: true - - uses: actions-rs/clippy-check@v1 + toolchain: nightly + components: clippy + - uses: giraffate/clippy-action@v1 with: - token: ${{ secrets.GITHUB_TOKEN }} - args: --features "${{ matrix.lua }},vendored,async,send,serialize,macros,parking_lot" + reporter: 'github-pr-review' + clippy_flags: --features "${{ matrix.lua }},vendored,async,send,serde,macros,anyhow,userdata-wrappers" diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml new file mode 100644 index 00000000..1cd3c1c9 --- /dev/null +++ b/.github/workflows/typos.yml @@ -0,0 +1,22 @@ +name: Spelling Check +on: + pull_request: + workflow_dispatch: + +permissions: + contents: read + +env: + CLICOLOR: 1 + +jobs: + spelling: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@main + - name: Check spelling + uses: crate-ci/typos@v1.42.1 + with: + config: ./typos.toml diff --git a/.gitignore b/.gitignore index 5c21d7d9..fcfec84e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ Cargo.lock .vscode/ .DS_Store +.stignore diff --git a/CHANGELOG.md b/CHANGELOG.md index f6a697ce..fb8e8562 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,458 @@ +## v0.12.0-rc.1 (Apr 21, 2026) + +- Rust 2024 edition +- Removed `Error::ToLuaConversionError` variant as it was unused (and not practically useful) +- New modules to group data types: `chunk`, `debug`, `error`, `function`, `table`, `string`, `state`, `thread`, `userdata`, `luau` +- Support `__todebugstring` metamethod for pretty formatting userdata value (for debugging) +- New `MaybeSync` trait that is required for userdata types +- Removed lifetime from `BorrowedStr` and `BorrowedBytes` +- New `Thread` methods: `is_resumable`, `is_running`, `is_finished`, `is_error` +- Added `Thread::state` to get raw Lua state pointer +- Luau `TextRequirer` is renamed to `FsRequirer` +- GC interface refactor: `Lua::gc_inc/Lua::gc_gen` is replaced with `gc_set_mode` +- Added `GcIncParams` and `GcGenParams` for GC tuning +- New `UserDataMethods::add_method_once` and `UserDataMethods::add_async_method_once` +- Initial Luau integer64 type support +- Changed interface of `Function::wrap/wrap_mut/wrap_async` to support any Error type +- Changed `AnyUserData::type_name` to return `LuaString` instead +- Added `UserDataOwned` wrapper to take ownership of userdata `T` and implements `FromLua` + +## v0.11.6 (Jan 27, 2026) + +- Added Lua 5.5 support (`lua55` feature flag) +- Luau updated to 0.705+ +- Added `AnyUserData::is_proxy` method to check if userdata is a proxy +- Added `num_params`, `num_upvalues`, `is_vararg` to `FunctionInfo` + +## v0.11.5 (Nov 22, 2025) + +- Luau updated to 0.701 +- Added `Lua::set_memory_category` and `Lua::heap_dump` functions to profile (Luau) memory +- Added `Lua::type_metatable` helper to get metatable of a primitive type +- Added `Lua::traceback` function to generate stack traces at different levels +- Added `add_method_once` /`add_async_method_once` UserData methods (experimental) +- Make `AnyUserData::type_name` public +- impl `IntoLuaMulti` for `&MultiValue` +- Bugfixes and async perf improvements + +## v0.11.4 (Sep 29, 2025) + +- Make `Value::to_serializable` public +- Add new serde option `detect_mixed_tables` (to encode mixed array+map tables) +- Add `ObjectLike::get_path` helper (for tables and userdata) + +## v0.11.3 (Aug 30, 2025) + +- Add `Lua::yield_with` to use as `coroutine.yield` functional replacement in async functions for any Lua +- Do not try to yield at non-yielable points in Luau interrupt (#632) +- Add `Buffer::cursor` method (Luau) +- Add `Lua::create_buffer_with_capacity` method (Luau) +- Make Lua reference values cheap to clone (only increments ref count) +- Fix panic on large (>67M entries) table creation + +## v0.11.2 (Aug 10, 2025) + +- Faster stack push for `Variadic` +- Fix handling Windows paths with drive letter in Luau require (#623) +- Make Luau registered aliases ascii case-insensitive (#620) +- Fix deserializing negative zeros `-0.0` (#618) + +## v0.11.1 (Jul 15, 2025) + +- Fixed bug exhausting Lua auxiliary stack and leaving it without reserve (#615) +- `Lua::push_c_function` now correctly handles OOM for Lua 5.1 and Luau + +## v0.11.0 (Jul 14, 2025) + +Changes since v0.11.0-beta.3 + +- Allow linking external Lua libraries in a build script (e.g. pluto) using `external` mlua-sys feature flag +- `Lua::inspect_stack` takes a callback with `&Debug` argument, instead of returning `Debug` directly +- Added `Debug::function` method to get function running at a given level +- `Debug::curr_line` is deprecated in favour of `Debug::current_line` that returns `Option` +- Added `Lua::set_globals` method to replace global environment +- `Table::set_metatable` now returns `Result<()>` (this operation can fail in sandboxed Luau mode) +- `impl ToString` replaced with `Into` in `UserData` registration +- `Value::as_str` and `Value::as_string_lossy` methods are deprecated (as they are non-idiomatic) +- Bugfixes and improvements + +## v0.11.0-beta.3 (Jun 23, 2025) + +- Luau in sandboxed mode has reduced options in `collectgarbage` function (to follow the official doc) +- `Function::deep_clone` now returns `Result` as this operation can trigger memory errors +- Luau "Require" resolves included Lua files relative to the current directory (#605) +- Fixed bug when finalizing `AsyncThread` on drop (`call_async` methods family) + +## v0.11.0-beta.2 (Jun 12, 2025) + +- Lua 5.4 updated to 5.4.8 +- Terminate Rust `Future` when `AsyncThread` is dropped (without relying on Lua GC) +- Added `loadstring` function to Luau +- Make `AsChunk` trait dyn-friendly +- Luau `Require` trait synced with Luau 0.674 +- Luau `Require` trait methods now can return `Error` variant (in `NavigateError` enum) +- Added `__type` to `Error`'s userdata metatable (for `typeof` function) +- `parking_log/send_guard` is moved to `userdata-wrappers` feature flag +- New `serde` feature flag to replace `serialize` (the old one is still available) + +## v0.11.0-beta.1 (May 7th, 2025) + +- New "require-by-string" for Luau (with `Require` trait and async support) +- Added `Thread::resume_error` support for Luau +- 52 bit integers support for Luau (this is a breaking change) +- New features for Luau compiler (constants, disabled builtins, known members) +- `AsyncThread` changed to `AsyncThread` (`A` pushed to stack immediately) +- Lifetime `'a` moved from `AsChunk<'a>` to `AsChunk::source where Self: 'a` +- `Lua::scope` pass `&Scope` instead of `&mut Scope` to closure +- Added global hooks support (Lua 5.1+) +- Added per-thread hooks support (Lua 5.1+) +- `Lua::init_from_ptr` renamed to `Lua::get_or_init_from_ptr` and returns `&Lua` +- `Lua:load_from_function` is deprecated (this is `register_module` now) +- Added `Lua::register_module` and `Lua::preload_module` + +## v0.10.4 (May 5th, 2025) + +- Luau updated to 0.672 +- New serde option `encode_empty_tables_as_array` to serialize empty tables as arrays +- Added `WeakLua` and `Lua::weak()` to create weak references to Lua state +- Trigger abort when Luau userdata destructors are panic (Luau GC does not support it) +- Added `AnyUserData::type_id()` method to get the type id of the userdata +- Added `Chunk::name()`, `Chunk::environment()` and `Chunk::mode()` functions +- Support borrowing underlying wrapped types for `UserDataRef` and `UserDataRefMut` (under `userdata-wrappers` feature) +- Added large (52bit) integers support for Luau +- Enable `serde` for `bstr` if `serialize` feature flag is enabled +- Recursive warnings (Lua 5.4) are no longer allowed +- Implemented `IntoLua`/`FromLua` for `BorrowedString` and `BorrowedBytes` +- Implemented `IntoLua`/`FromLua` for `char` +- Enable `Thread::reset()` for all Lua versions (limited support for 5.1-5.3) +- Bugfixes and improvements + +## v0.10.3 (Jan 27th, 2025) + +- Set `Default` for `Value` to be `Nil` +- Allow exhaustive match on `Value` (#502) +- Add `Table::set_safeenv` method (Luau) + +## v0.10.2 (Dec 1st, 2024) + +- Switch proc-macro-error to proc-macro-error2 (#493) +- Do not allow Lua to run GC finalizers on ref thread (#491) +- Fix chunks loading in Luau when memory limit is enforced (#488) +- Added `String::wrap` method to wrap arbitrary `AsRef<[u8]>` into `impl IntoLua` +- Better FreeBSD/OpenBSD support (thanks to cos) +- Delay "any" userdata metatable creation until first instance is created (#482) +- Reduce amount of generated code for `UserData` (less generics) + +## v0.10.1 (Nov 9th, 2024) + +- Minimal Luau updated to 0.650 +- Added Luau native vector library support (this can change behavior if you use `vector` function!) +- Added Lua `String::display` method +- Improved pretty-printing for Lua tables (#478) +- Added `Scope::create_any_userdata` to create Lua objects from any non-`'static` Rust types +- Added `AnyUserData::destroy` method +- New `userdata-wrappers` feature to `impl UserData` for `Rc`/`Arc`/`Rc>`/`Arc>` (similar to v0.9) +- `UserDataRef` in `send` mode now uses shared lock if `T: Sync` (and exclusive lock otherwise) +- Added `Scope::add_destructor` to attach custom destructors +- Added `Lua::try_app_data_ref` and `Lua::try_app_data_mut` methods +- Added `From` and `Into` support to `MultiValue` and `Variadic` types +- Bug fixes and improvements (#477 #479) + +## v0.10.0 (Oct 25th, 2024) + +Changes since v0.10.0-rc.1 + +- Added `error-send` feature flag (disabled by default) to require `Send + Sync` for `Error` +- Some performance improvements + +## v0.10.0-rc.1 + +- `Lua::scope` is back +- Support yielding from hooks for Lua 5.3+ +- Support setting metatable for Lua builtin types (number/string/function/etc) +- Added `LuaNativeFn`/`LuaNativeFnMut`/`LuaNativeAsyncFn` traits for using in `Function::wrap` +- Added `Error::chain` method to return iterator over nested errors +- Added `Lua::exec_raw` helper to execute low-level Lua C API code +- Added `Either` enum to combine two types into a single one +- Added a new `Buffer` type for Luau +- Added `Value::is_error` and `Value::as_error` helpers +- Added `Value::Other` variant to represent unknown Lua types (eg LuaJIT CDATA) +- Added (optional) `anyhow` feature to implement `IntoLua` for `anyhow::Error` +- Added `IntoLua`/`FromLua` for `OsString`/`OsStr` and `PathBuf`/`Path` + +## v0.10.0-beta.2 + +- Updated `ThreadStatus` enum to include `Running` and `Finished` variants. +- `Error::CoroutineInactive` renamed to `Error::CoroutineUnresumable`. +- `IntoLua`/`IntoLuaMulti` now uses `impl trait` syntax for args (shorten from `a.get::<_, T>` to `a.get::`). +- Removed undocumented `Lua::into_static`/`from_static` methods. +- Futures now require `Send` bound if `send` feature is enabled. +- Dropped lifetime from `UserDataMethods` and `UserDataFields` traits. +- `Compiler::compile()` now returns `Result` (Luau). +- Removed `Clone` requirement from `UserDataFields::add_field()`. +- `TableExt` and `AnyUserDataExt` traits were combined into `ObjectLike` trait. +- Disabled `send` feature in module mode (since we don't have exclusive access to Lua). +- `Chunk::set_environment()` takes `Table` instead of `IntoLua` type. +- Reduced the compile time contribution of `next_key_seed` and `next_value_seed`. +- Reduced the compile time contribution of `serde_userdata`. +- Performance improvements. + +## v0.10.0-beta.1 + +- Dropped `'lua` lifetime (subtypes now store a weak reference to Lua) +- Removed (experimental) owned types (they no longer needed) +- Make Lua types truly `Send` and `Sync` (when enabling `send` feature flag) +- Removed `UserData` impl for Rc/Arc types ("any" userdata functions can be used instead) +- `Lua::replace_registry_value` takes `&mut RegistryKey` +- `Lua::scope` temporary disabled (will be re-added in the next release) + +## v0.9.9 + +- Minimal Luau updated to 0.629 +- Fixed bug when attempting to reset or resume already running coroutines (#416). +- Added `RegistryKey::id()` method to get the underlying Lua registry key id. + +## v0.9.8 + +- Fixed serializing same table multiple times (#408) +- Use `mlua-sys` v0.6 (to support Luau 0.624+) +- Fixed cross compilation of windows dlls from unix (#394) + +## v0.9.7 + +- Implemented `IntoLua` for `RegistryKey` +- Mark `__idiv` metamethod as available for luau +- Added `Function::deep_clone()` method (Luau) +- Added `SerializeOptions::detect_serde_json_arbitrary_precision` option +- Added `Lua::create_buffer()` method (Luau) +- Support serializing buffer type as a byte slice (Luau) +- Perf: Implemented `push_into_stack`/`from_stack` for `Option` +- Added `Lua::create_ser_any_userdata()` method + +## v0.9.6 + +- Added `to_pointer` function to `Function`/`Table`/`Thread` +- Implemented `IntoLua` for `&Value` +- Implemented `FromLua` for `RegistryKey` +- Faster (~5%) table array traversal during serialization +- Some performance improvements for bool/int types + +## v0.9.5 + +- Minimal Luau updated to 0.609 +- Luau max stack size increased to 1M (from 100K) +- Implemented `IntoLua` for refs to `String`/`Table`/`Function`/`AnyUserData`/`Thread` + `RegistryKey` +- Implemented `IntoLua` and `FromLua` for `OwnedThread`/`OwnedString` +- Fixed `FromLua` derive proc macro to cover more cases + +## v0.9.4 + +- Fixed loading all-in-one modules under mixed states (eg. main state and coroutines) + +## v0.9.3 + +- WebAssembly support (`wasm32-unknown-emscripten` target) +- Performance improvements (faster Lua function calls for lua51/jit/luau) + +## v0.9.2 + +- Added binary modules support to Luau +- Added Luau package module (uses `StdLib::PACKAGE`) with loaders (follows lua5.1 interface) +- Added support of Luau 0.601+ buffer type (represented as userdata in Rust) +- LuaJIT `cdata` type is also represented as userdata in Rust (instead of panic) +- Vendored LuaJIT switched to rolling vanilla (from openresty) +- Added `Table::for_each` method for fast table pairs traversal (faster than `pairs`) +- Performance improvements around table traversal (and faster serialization) +- Bug fixes and improvements + +## v0.9.1 + +- impl Default for Lua +- impl IntoLuaMulti for `std::result::Result<(), E>` +- Fix using wrong userdata index after processing Variadic args (#311) + +## v0.9.0 + +Changes since v0.9.0-rc.3 + +- Improved non-static (scoped) userdata support +- Added `Scope::create_any_userdata()` method +- Added `Lua::set_vector_metatable()` method (`unstable` feature flag) +- Added `OwnedThread` type (`unstable` feature flag) +- Minimal Luau updated to 0.590 +- Added new option `sort_keys` to `DeserializeOptions` (`Lua::from_value()` method) +- Changed `Table::raw_len()` output type to `usize` +- Helper functions for `Value` (eg: `Value::as_number()`/`Value::as_string`/etc) +- Performance improvements + +## v0.9.0-rc.3 + +- Minimal Luau updated to 0.588 + +## v0.9.0-rc.2 + +- Added `#[derive(FromLua)]` macro to opt-in into `FromLua where T: 'static + Clone` (userdata type). +- Support vendored module mode for windows (raw-dylib linking, Rust 1.71+) +- `module` and `vendored` features are now mutually exclusive +- Use `C-unwind` ABI (Rust 1.71+) +- Changed `AsChunk` trait to support capturing wrapped Lua types + +## v0.9.0-rc.1 + +- `UserDataMethods::add_async_method()` takes `&T` instead of cloning `T` +- Implemented `PartialEq<[T]>` for tables +- Added Luau 4-dimensional vectors support (`luau-vector4` feature) +- `Table::sequence_values()` iterator no longer uses any metamethods (`Table::raw_sequence_values()` is deprecated) +- Added `Table:is_empty()` function that checks both hash and array parts +- Refactored Debug interface +- Re-exported `ffi` (`mlua-sys`) crate for easier writing of unsafe code +- Refactored Lua 5.4 warnings interface +- Take `&str` as function name in `TableExt` and `AnyUserDataExt` traits +- Added module attribule `skip_memory_check` to improve performance +- Added `AnyUserData::wrap()` to provide more easy way of creating _any_ userdata in Lua + +## v0.9.0-beta.3 + +- Added `OwnedAnyUserData::take()` +- Switch to `DeserializeOwned` +- Overwrite error context when called multiple times +- New feature flag `luau-jit` to enable (experimental) Luau codegen backend +- Set `__name` field in userdata metatable +- Added `Value::to_string()` method similar to `luaL_tolstring` +- Lua 5.4.6 +- Application data container now allows to mutably and immutably borrow different types at the same time +- Performance optimizations +- Support getting and setting environment for Lua functions. +- Added `UserDataFields::add_field()` method to add static fields to UserData + +Breaking changes: +- Require environment to be a `Table` instead of `Value` in Chunks. +- `AsChunk::env()` renamed to `AsChunk::environment()` + +## v0.9.0-beta.2 + +New features: +- Added `Thread::set_hook()` function to set hook on threads +- Added pretty print to the Debug formatting to Lua `Value` and `Table` +- ffi layer moved to `mlua-sys` crate +- Added OwnedString (unstable) + +Breaking changes: +- Refactor `HookTriggers` (make it const) + +## v0.9.0-beta.1 + +New features: +- Owned Lua types (unstable feature flag) +- New functions `Function::wrap`/`Function::wrap_mut`/`Function::wrap_async` +- `Lua::register_userdata_type()` to register a custom userdata types (without requiring `UserData` trait) +- `Lua::create_any_userdata()` +- Added `create_userdata_ref`/`create_userdata_ref_mut` for scopes +- Added `AnyUserDataExt` trait with auxiliary functions for `AnyUserData` +- Added `UserDataRef` and `UserDataRefMut` type wrapped that implement `FromLua` +- Improved error handling: + * Improved error reporting when calling Rust functions from Lua. + * Added `Error::BadArgument` to help identify bad argument position or name + * Added `ErrorContext` extension trait to attach additional context to `Error` + +Breaking changes: +- Refactored `AsChunk` trait +- `ToLua`/`ToLuaMulti` renamed to `IntoLua`/`IntoLuaMulti` +- Renamed `to_lua_err` to `into_lua_err` +- Removed `FromLua` impl for `T: UserData+Clone` +- Removed `Lua::async_scope` +- Added `&Lua` arg to Luau interrupt callback + +Other: +- Better Debug for String +- Allow deserializing values from serializable UserData using `Lua::from_value()` method +- Added `Table::clear()` method +- Added `Error::downcast_ref()` method +- Support setting memory limit for Lua 5.1/JIT/Luau +- Support setting module name in `#[lua_module(name = "...")]` macro +- Minor fixes and improvements + +## v0.8.10 + +- Update to Luau 0.590 (luau0-src to 0.7.x) +- Fix loading luau code starting with \t +- Pin lua-src and luajit-src versions + +## v0.8.9 + +- Update minimal (vendored) Lua 5.4 to 5.4.6 +- Use `lua_closethread` instead of `lua_resetthread` in vendored mode (Lua 5.4.6) +- Allow deserializing Lua null into unit (`()`) or unit struct. + +## v0.8.8 + +- Fix potential deadlock when trying to reuse dropped registry keys. +- Optimize userdata methods call when __index and fields_getters are nil + +## v0.8.7 + +- Minimum Luau updated to 0.555 (`LUAI_MAXCSTACK` limit increased to 100000) +- `_VERSION` in Luau now includes version number +- Fixed lifetime of `DebugNames` in `Debug::names()` and `DebugSource` in `Debug::source()` +- Fixed subtraction overflow when calculating index for `MultiValue::get()` + +## v0.8.6 + +- Fixed bug when recycled Registry slot can be set to Nil + +## v0.8.5 + +- Fixed potential unsoundness when using `Layout::from_size_align_unchecked` and Rust 1.65+ +- Performance optimizations around string and table creation in standalone mode +- Added fast track path to Table `get`/`set`/`len` methods without metatable +- Added new methods `push`/`pop`/`raw_push`/`raw_pop` to Table +- Fix getting caller information from `Lua::load` +- Better checks and tests when trying to modify a Luau readonly table + +## v0.8.4 + +- Minimal Luau updated to 0.548 + +## v0.8.3 + +- Close to-be-closed variables for Lua 5.4 when using call_async functions (#192) +- Fixed Lua assertion when inspecting another thread stack. (#195) +- Use more reliable way to create LuaJIT VM (which can fail if use Rust allocator on non-x86 platforms) + +## v0.8.2 + +- Performance optimizations in handling UserData +- Minimal Luau updated to 0.536 +- Fixed bug in `Function::bind` when passing empty binds and no arguments (#189) + +## v0.8.1 + +- Added `Lua::create_proxy` for accessing to UserData static fields and functions without instance +- Added `Table::to_pointer()` and `String::to_pointer()` functions +- Bugfixes and improvements (#176 #179) + +## v0.8.0 +Changes since 0.7.4 +- Luau support +- Removed C glue +- Added async support to `__index` and `__newindex` metamethods +- Added `Function::info()` to get information about functions (#149). +- Added `parking_lot` dependency under feature flag (for `UserData`) +- `Hash` implementation for Lua String +- Added `Value::to_pointer()` function +- Performance improvements + +Breaking changes: +- Refactored `AsChunk` trait (added implementation for `Path` and `PathBuf`). + +## v0.8.0-beta.5 + +- Lua sources no longer needed to build modules +- Added `__iter` metamethod for Luau +- Added `Value::to_pointer()` function +- Added `Function::coverage` for Luau to obtain coverage report +- Bugfixes and improvements (#153 #161 #168) + ## v0.8.0-beta.4 - Removed `&Lua` from `Lua::set_interrupt` as it's not safe (introduced in v0.8.0-beta.3) @@ -28,7 +483,7 @@ ## v0.8.0-beta.1 -- Roblox Luau support +- Luau support - Refactored ffi module. C glue is no longer required - Added async support to `__index` and `__newindex` metamethods @@ -141,7 +596,7 @@ Breaking changes: - [**Breaking**] Removed `AnyUserData::has_metamethod()` - Added `Thread::reset()` for luajit/lua54 to recycle threads. It's possible to attach a new function to a thread (coroutine). -- Added `chunk!` macro support to load chunks of Lua code using the Rust tokenizer and optinally capturing Rust variables. +- Added `chunk!` macro support to load chunks of Lua code using the Rust tokenizer and optionally capturing Rust variables. - Improved error reporting (`Error`'s `__tostring` method formats full stacktraces). This is useful in the module mode. ## v0.6.0-beta.1 @@ -197,7 +652,7 @@ Breaking changes: - Lua 5.4 support with `MetaMethod::Close`. - `lua53` feature is disabled by default. Now preferred Lua version have to be chosen explicitly. -- Provide safety guaraness for Lua state, which means that potenially unsafe operations, like loading C modules (using `require` or `package.loadlib`) are disabled. Equalient for the previous `Lua::new()` function is `Lua::unsafe_new()`. +- Provide safety guarantees for Lua state, which means that potentially unsafe operations, like loading C modules (using `require` or `package.loadlib`) are disabled. Equivalent to the previous `Lua::new()` function is `Lua::unsafe_new()`. - New `send` feature to require `Send`. - New `module` feature, that disables linking to Lua Core Libraries. Required for modules. - Don't allow `'callback` outlive `'lua` in `Lua::create_function()` to fix [the unsoundness](tests/compile/static_callback_args.rs). diff --git a/Cargo.toml b/Cargo.toml index fb8a957b..1488205b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,107 +1,124 @@ [package] name = "mlua" -version = "0.8.0-beta.4" # remember to update html_root_url and mlua_derive -authors = ["Aleksandr Orlenko ", "kyren "] -edition = "2018" -repository = "https://github.com/khvzak/mlua" +version = "0.12.0-rc.1" # remember to update mlua_derive +authors = ["Aleksandr Orlenko ", "kyren "] +rust-version = "1.88" +edition = "2024" +repository = "https://github.com/mlua-rs/mlua" documentation = "https://docs.rs/mlua" readme = "README.md" -keywords = ["lua", "luajit", "async", "futures", "scripting"] +keywords = ["lua", "luajit", "luau", "async", "scripting"] categories = ["api-bindings", "asynchronous"] license = "MIT" -links = "lua" -build = "build/main.rs" description = """ -High level bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau +High level bindings to Lua 5.5/5.4/5.3/5.2/5.1 (including LuaJIT) and Luau with async/await features and support of writing native Lua modules in Rust. """ [package.metadata.docs.rs] -features = ["lua54", "vendored", "async", "send", "serialize", "macros", "parking_lot"] +features = ["lua55", "vendored", "async", "send", "serde", "macros"] rustdoc-args = ["--cfg", "docsrs"] [workspace] members = [ "mlua_derive", + "mlua-sys", ] [features] -lua54 = [] -lua53 = [] -lua52 = [] -lua51 = [] -luajit = [] -luajit52 = ["luajit"] -luau = ["luau0-src"] -vendored = ["lua-src", "luajit-src"] -module = ["mlua_derive"] -async = ["futures-core", "futures-task", "futures-util"] -send = [] -serialize = ["serde", "erased-serde"] +lua55 = ["ffi/lua55"] +lua54 = ["ffi/lua54"] +lua53 = ["ffi/lua53"] +lua52 = ["ffi/lua52"] +lua51 = ["ffi/lua51"] +luajit = ["ffi/luajit"] +luajit52 = ["luajit", "ffi/luajit52"] +luau = ["ffi/luau"] +luau-jit = ["luau", "ffi/luau-codegen"] +luau-vector4 = ["luau", "ffi/luau-vector4"] +vendored = ["ffi/vendored"] +module = ["mlua_derive", "ffi/module"] +async = ["dep:futures-util"] +send = ["error-send"] +error-send = [] +serde = ["dep:serde", "dep:erased-serde", "dep:serde-value", "bstr/serde"] macros = ["mlua_derive/macros"] +anyhow = ["dep:anyhow", "error-send"] +userdata-wrappers = ["parking_lot/send_guard"] + +# deprecated features +serialize = ["serde"] [dependencies] -mlua_derive = { version = "=0.8.0-beta.1", optional = true, path = "mlua_derive" } -bstr = { version = "0.2", features = ["std"], default_features = false } -once_cell = { version = "1.0" } +mlua_derive = { version = "=0.11.0", optional = true, path = "mlua_derive" } +bstr = { version = "1.0", features = ["std"], default-features = false } +either = "1.0" num-traits = { version = "0.2.14" } -rustc-hash = "1.0" -futures-core = { version = "0.3.5", optional = true } -futures-task = { version = "0.3.5", optional = true } -futures-util = { version = "0.3.5", optional = true } +rustc-hash = "2.0" +futures-util = { version = "0.3", optional = true, default-features = false, features = ["std"] } serde = { version = "1.0", optional = true } -erased-serde = { version = "0.3", optional = true } -parking_lot = { version = "0.12", optional = true } +erased-serde = { version = "0.4", optional = true } +serde-value = { version = "0.7", optional = true } +parking_lot = { version = "0.12", features = ["arc_lock"] } +anyhow = { version = "1.0", optional = true } +libc = "0.2" -[build-dependencies] -cc = { version = "1.0" } -pkg-config = { version = "0.3.17" } -lua-src = { version = ">= 544.0.0, < 550.0.0", optional = true } -luajit-src = { version = ">= 210.3.1, < 220.0.0", optional = true } -luau0-src = { version = "0.3", optional = true } +ffi = { package = "mlua-sys", version = "0.11.0-rc.1", path = "mlua-sys" } [dev-dependencies] -rustyline = "9.0" -criterion = { version = "0.3.4", features = ["html_reports", "async_tokio"] } trybuild = "1.0" -futures = "0.3.5" -hyper = { version = "0.14", features = ["client", "server"] } -reqwest = { version = "0.11", features = ["json"] } -tokio = { version = "1.0", features = ["full"] } -futures-timer = "3.0" +tokio = { version = "1.0", features = ["macros", "rt", "time"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["arbitrary_precision"] } maplit = "1.0" +static_assertions = "1.0" + +[target.'cfg(not(target_family = "wasm"))'.dev-dependencies] +hyper = { version = "1.2", features = ["full"] } +hyper-util = { version = "0.1.3", features = ["full"] } +http-body-util = "0.1.1" +reqwest = { version = "0.13", features = ["json"] } tempfile = "3" +criterion = { version = "0.8", features = ["async_tokio"] } +rustyline = "18.0" +tokio = { version = "1.0", features = ["full"] } + +[lints.rust] +unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] } [[bench]] name = "benchmark" harness = false required-features = ["async"] +[[bench]] +name = "serde" +harness = false +required-features = ["serde"] + [[example]] name = "async_http_client" required-features = ["async", "macros"] [[example]] name = "async_http_reqwest" -required-features = ["async", "serialize", "macros"] +required-features = ["async", "serde", "macros"] [[example]] name = "async_http_server" -required-features = ["async", "macros"] +required-features = ["async", "macros", "send"] [[example]] name = "async_tcp_server" -required-features = ["async", "macros"] +required-features = ["async", "macros", "send"] [[example]] name = "guided_tour" required-features = ["macros"] [[example]] -name = "serialize" -required-features = ["serialize"] +name = "serde" +required-features = ["serde"] [[example]] name = "userdata" diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 00000000..eb018184 --- /dev/null +++ b/FAQ.md @@ -0,0 +1,21 @@ +# mlua FAQ + +This file is for general questions that don't fit into the README or crate docs. + +## Loading a C module fails with error `undefined symbol: lua_xxx`. How to fix? + +Add the following rustflags to your [.cargo/config](http://doc.crates.io/config.html) in order to properly export Lua symbols: + +```toml +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "link-args=-rdynamic"] + +[target.x86_64-apple-darwin] +rustflags = ["-C", "link-args=-rdynamic"] +``` + +## I want to add support for a Lua VM fork to mlua. Do you accept pull requests? + +Adding new feature flag to support a Lua VM fork is a major step that requires huge effort to maintain it. +Regular updates, testing, checking compatibility, etc. +That's why I don't plan to support new Lua VM forks or other languages in mlua. diff --git a/README.md b/README.md index 35ad9926..b253b571 100644 --- a/README.md +++ b/README.md @@ -1,69 +1,82 @@ # mlua [![Build Status]][github-actions] [![Latest Version]][crates.io] [![API Documentation]][docs.rs] [![Coverage Status]][codecov.io] ![MSRV] -[Build Status]: https://github.com/khvzak/mlua/workflows/CI/badge.svg -[github-actions]: https://github.com/khvzak/mlua/actions +[Build Status]: https://github.com/mlua-rs/mlua/workflows/CI/badge.svg +[github-actions]: https://github.com/mlua-rs/mlua/actions [Latest Version]: https://img.shields.io/crates/v/mlua.svg [crates.io]: https://crates.io/crates/mlua [API Documentation]: https://docs.rs/mlua/badge.svg [docs.rs]: https://docs.rs/mlua -[Coverage Status]: https://codecov.io/gh/khvzak/mlua/branch/master/graph/badge.svg?token=99339FS1CG -[codecov.io]: https://codecov.io/gh/khvzak/mlua -[MSRV]: https://img.shields.io/badge/rust-1.53+-brightgreen.svg?&logo=rust +[Coverage Status]: https://codecov.io/gh/mlua-rs/mlua/branch/main/graph/badge.svg?token=99339FS1CG +[codecov.io]: https://codecov.io/gh/mlua-rs/mlua +[MSRV]: https://img.shields.io/badge/rust-1.79+-brightgreen.svg?&logo=rust -[Guided Tour](examples/guided_tour.rs) +[Guided Tour] | [Benchmarks] | [FAQ] -`mlua` is bindings to [Lua](https://www.lua.org) programming language for Rust with a goal to provide -_safe_ (as far as it's possible), high level, easy to use, practical and flexible API. +[Guided Tour]: examples/guided_tour.rs +[Benchmarks]: https://github.com/khvzak/script-bench-rs +[FAQ]: FAQ.md -Started as `rlua` fork, `mlua` supports Lua 5.4, 5.3, 5.2, 5.1 (including LuaJIT) and [Roblox Luau] and allows to write native Lua modules in Rust as well as use Lua in a standalone mode. +## The main branch is the development version of `mlua`. Please see the [v0.11](https://github.com/mlua-rs/mlua/tree/v0.11) branch for the stable versions of `mlua`. -`mlua` tested on Windows/macOS/Linux including module mode in [GitHub Actions] on `x86_64` platform and cross-compilation to `aarch64` (other targets are also supported). +`mlua` is a set of bindings to the [Lua](https://www.lua.org) programming language for Rust with a goal of providing a +_safe_ (as much as possible), high level, easy to use, practical and flexible API. -[GitHub Actions]: https://github.com/khvzak/mlua/actions -[Roblox Luau]: https://luau-lang.org +Started as an `rlua` fork, `mlua` supports Lua 5.5, 5.4, 5.3, 5.2, 5.1 (including LuaJIT) and [Luau] and allows writing native Lua modules in Rust as well as using Lua in a standalone mode. + +`mlua` is tested on Windows/macOS/Linux including module mode in [GitHub Actions] on `x86_64` platforms and cross-compilation to `aarch64` (other targets are also supported). + +WebAssembly (WASM) is supported through the `wasm32-unknown-emscripten` target for all Lua/Luau versions excluding JIT. + +[GitHub Actions]: https://github.com/mlua-rs/mlua/actions +[Luau]: https://luau.org ## Usage ### Feature flags -`mlua` uses feature flags to reduce the amount of dependencies, compiled code and allow to choose only required set of features. +`mlua` uses feature flags to reduce the number of dependencies and compiled code, and allow choosing only the required set of features. Below is a list of the available feature flags. By default `mlua` does not enable any features. -* `lua54`: activate Lua [5.4] support -* `lua53`: activate Lua [5.3] support -* `lua52`: activate Lua [5.2] support -* `lua51`: activate Lua [5.1] support -* `luajit`: activate [LuaJIT] support -* `luajit52`: activate [LuaJIT] support with partial compatibility with Lua 5.2 -* `luau`: activate [Luau] support (auto vendored mode) -* `vendored`: build static Lua(JIT) library from sources during `mlua` compilation using [lua-src] or [luajit-src] crates +* `lua55`: enable Lua [5.5] support +* `lua54`: enable Lua [5.4] support +* `lua53`: enable Lua [5.3] support +* `lua52`: enable Lua [5.2] support +* `lua51`: enable Lua [5.1] support +* `luajit`: enable [LuaJIT] support +* `luajit52`: enable [LuaJIT] support with partial compatibility with Lua 5.2 +* `luau`: enable [Luau] support (auto vendored mode) +* `luau-jit`: enable [Luau] support with JIT backend. +* `luau-vector4`: enable [Luau] support with 4-dimensional vector. +* `vendored`: build static Lua(JIT) libraries from sources during `mlua` compilation using [lua-src] or [luajit-src] * `module`: enable module mode (building loadable `cdylib` library for Lua) * `async`: enable async/await support (any executor can be used, eg. [tokio] or [async-std]) -* `send`: make `mlua::Lua` transferable across thread boundaries (adds [`Send`] requirement to `mlua::Function` and `mlua::UserData`) -* `serialize`: add serialization and deserialization support to `mlua` types using [serde] framework +* `send`: make `mlua::Lua: Send + Sync` (adds [`Send`] requirement to `mlua::Function` and `mlua::UserData`) +* `error-send`: make `mlua:Error: Send + Sync` +* `serde`: add serialization and deserialization support to `mlua` types using [serde] * `macros`: enable procedural macros (such as `chunk!`) -* `parking_lot`: support UserData types wrapped in [parking_lot]'s primitives (`Arc` and `Arc`) +* `anyhow`: enable `anyhow::Error` conversion into Lua +* `userdata-wrappers`: opt into `impl UserData` for `Rc`/`Arc`/`Rc>`/`Arc>` where `T: UserData` +[5.5]: https://www.lua.org/manual/5.5/manual.html [5.4]: https://www.lua.org/manual/5.4/manual.html [5.3]: https://www.lua.org/manual/5.3/manual.html [5.2]: https://www.lua.org/manual/5.2/manual.html [5.1]: https://www.lua.org/manual/5.1/manual.html [LuaJIT]: https://luajit.org/ -[Luau]: https://github.com/Roblox/luau -[lua-src]: https://github.com/khvzak/lua-src-rs -[luajit-src]: https://github.com/khvzak/luajit-src-rs +[Luau]: https://github.com/luau-lang/luau +[lua-src]: https://github.com/mlua-rs/lua-src-rs +[luajit-src]: https://github.com/mlua-rs/luajit-src-rs [tokio]: https://github.com/tokio-rs/tokio [async-std]: https://github.com/async-rs/async-std [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html [serde]: https://github.com/serde-rs/serde -[parking_lot]: https://github.com/Amanieu/parking_lot ### Async/await support `mlua` supports async/await for all Lua versions including Luau. -This works using Lua [coroutines](https://www.lua.org/manual/5.3/manual.html#2.6) and require running [Thread](https://docs.rs/mlua/latest/mlua/struct.Thread.html) along with enabling `feature = "async"` in `Cargo.toml`. +This works using Lua [coroutines](https://www.lua.org/manual/5.3/manual.html#2.6) and requires running [Thread](https://docs.rs/mlua/latest/mlua/struct.Thread.html) along with enabling `feature = "async"` in `Cargo.toml`. **Examples**: - [HTTP Client](examples/async_http_client.rs) @@ -71,11 +84,25 @@ This works using Lua [coroutines](https://www.lua.org/manual/5.3/manual.html#2.6 - [HTTP Server](examples/async_http_server.rs) - [TCP Server](examples/async_tcp_server.rs) -### Serialization (serde) support -With `serialize` feature flag enabled, `mlua` allows you to serialize/deserialize any type that implements [`serde::Serialize`] and [`serde::Deserialize`] into/from [`mlua::Value`]. In addition `mlua` provides [`serde::Serialize`] trait implementation for it (including `UserData` support). +**shell command examples**: +```shell +# async http client (hyper) +cargo run --example async_http_client --features=lua54,async,macros + +# async http client (reqwest) +cargo run --example async_http_reqwest --features=lua54,async,macros,serde + +# async http server +cargo run --example async_http_server --features=lua54,async,macros,send +curl -v http://localhost:3000 +``` + +### Serde support + +With the `serde` feature flag enabled, `mlua` allows you to serialize/deserialize any type that implements [`serde::Serialize`] and [`serde::Deserialize`] into/from [`mlua::Value`]. In addition, `mlua` provides the [`serde::Serialize`] trait implementation for `mlua::Value` (including `UserData` support). -[Example](examples/serialize.rs) +[Example](examples/serde.rs) [`serde::Serialize`]: https://docs.serde.rs/serde/ser/trait.Serialize.html [`serde::Deserialize`]: https://docs.serde.rs/serde/de/trait.Deserialize.html @@ -85,28 +112,28 @@ With `serialize` feature flag enabled, `mlua` allows you to serialize/deserializ You have to enable one of the features: `lua54`, `lua53`, `lua52`, `lua51`, `luajit(52)` or `luau`, according to the chosen Lua version. -By default `mlua` uses `pkg-config` tool to find lua includes and libraries for the chosen Lua version. -In most cases it works as desired, although sometimes could be more preferable to use a custom lua library. -To achieve this, mlua supports `LUA_INC`, `LUA_LIB`, `LUA_LIB_NAME` and `LUA_LINK` environment variables. +By default `mlua` uses `pkg-config` to find Lua includes and libraries for the chosen Lua version. +In most cases it works as desired, although sometimes it may be preferable to use a custom Lua library. +To achieve this, mlua supports the `LUA_LIB`, `LUA_LIB_NAME` and `LUA_LINK` environment variables. `LUA_LINK` is optional and may be `dylib` (a dynamic library) or `static` (a static library, `.a` archive). -An example how to use them: +An example of how to use them: ``` sh -my_project $ LUA_INC=$HOME/tmp/lua-5.2.4/src LUA_LIB=$HOME/tmp/lua-5.2.4/src LUA_LIB_NAME=lua LUA_LINK=static cargo build +my_project $ LUA_LIB=$HOME/tmp/lua-5.2.4/src LUA_LIB_NAME=lua LUA_LINK=static cargo build ``` -`mlua` also supports vendored lua/luajit using the auxiliary crates [lua-src](https://crates.io/crates/lua-src) and +`mlua` also supports vendored Lua/LuaJIT using the auxiliary crates [lua-src](https://crates.io/crates/lua-src) and [luajit-src](https://crates.io/crates/luajit-src). -Just enable the `vendored` feature and cargo will automatically build and link specified lua/luajit version. This is the easiest way to get started with `mlua`. +Just enable the `vendored` feature and cargo will automatically build and link the specified Lua/LuaJIT version. This is the easiest way to get started with `mlua`. ### Standalone mode -In a standalone mode `mlua` allows to add to your application scripting support with a gently configured Lua runtime to ensure safety and soundness. +In standalone mode, `mlua` allows adding scripting support to your application with a properly configured Lua runtime to ensure safety and soundness. -Add to `Cargo.toml` : +Add to `Cargo.toml`: ``` toml [dependencies] -mlua = { version = "0.8.0-beta.4", features = ["lua54", "vendored"] } +mlua = { version = "0.11", features = ["lua54", "vendored"] } ``` `main.rs` @@ -130,21 +157,21 @@ fn main() -> LuaResult<()> { ``` ### Module mode -In a module mode `mlua` allows to create a compiled Lua module that can be loaded from Lua code using [`require`](https://www.lua.org/manual/5.4/manual.html#pdf-require). In this case `mlua` uses an external Lua runtime which could lead to potential unsafety due to unpredictability of the Lua environment and usage of libraries such as [`debug`](https://www.lua.org/manual/5.4/manual.html#6.10). +In module mode, `mlua` allows creating a compiled Lua module that can be loaded from Lua code using [`require`](https://www.lua.org/manual/5.4/manual.html#pdf-require). In this case `mlua` uses an external Lua runtime which could lead to potential unsafety due to the unpredictability of the Lua environment and usage of libraries such as [`debug`](https://www.lua.org/manual/5.4/manual.html#6.10). [Example](examples/module) -Add to `Cargo.toml` : +Add to `Cargo.toml`: ``` toml [lib] crate-type = ["cdylib"] [dependencies] -mlua = { version = "0.8.0-beta.4", features = ["lua54", "vendored", "module"] } +mlua = { version = "0.11", features = ["lua54", "module"] } ``` -`lib.rs` : +`lib.rs`: ``` rust use mlua::prelude::*; @@ -171,7 +198,7 @@ $ lua5.4 -e 'require("my_module").hello("world")' hello, world! ``` -On macOS, you need to set additional linker arguments. One option is to compile with `cargo rustc --release -- -C link-arg=-undefined -C link-arg=dynamic_lookup`, the other is to create a `.cargo/config` with the following content: +On macOS, you need to set additional linker arguments. One option is to compile with `cargo rustc --release -- -C link-arg=-undefined -C link-arg=dynamic_lookup`, the other is to create a `.cargo/config.toml` with the following content: ``` toml [target.x86_64-apple-darwin] rustflags = [ @@ -186,29 +213,31 @@ rustflags = [ ] ``` On Linux you can build modules normally with `cargo build --release`. -Vendored and non-vendored builds are supported for these OS. -On Windows `vendored` mode for modules is not supported since you need to link to a Lua dll. -Easiest way is to use either MinGW64 (as part of [MSYS2](https://github.com/msys2/msys2) package) with `pkg-config` or -MSVC with `LUA_INC` / `LUA_LIB` / `LUA_LIB_NAME` environment variables. +On Windows the target module will be linked with the `lua5x.dll` library (depending on your feature flags). +Your main application should provide this library. -More details about compiling and linking Lua modules can be found on the [Building Modules](http://lua-users.org/wiki/BuildingModules) page. +Module builds don't require Lua binaries or headers to be installed on the system. ### Publishing to luarocks.org -There is a LuaRocks build backend for mlua modules [`luarocks-build-rust-mlua`]. +There is a LuaRocks build backend for mlua modules: [`luarocks-build-rust-mlua`]. Modules written in Rust and published to luarocks: +- [`decasify`](https://github.com/alerque/decasify) - [`lua-ryaml`](https://github.com/khvzak/lua-ryaml) +- [`tiktoken_core`](https://github.com/gptlang/lua-tiktoken) +- [`toml-edit`](https://github.com/vhyrro/toml-edit.lua) +- [`typst-lua`](https://github.com/rousbound/typst-lua) [`luarocks-build-rust-mlua`]: https://luarocks.org/modules/khvzak/luarocks-build-rust-mlua ## Safety -One of the `mlua` goals is to provide *safe* API between Rust and Lua. -Every place where the Lua C API may trigger an error longjmp in any way is protected by `lua_pcall`, -and the user of the library is protected from directly interacting with unsafe things like the Lua stack, -and there is overhead associated with this safety. +One of `mlua`'s goals is to provide a *safe* API between Rust and Lua. +Every place where the Lua C API may trigger an error longjmp is protected by `lua_pcall`, +and the user of the library is protected from directly interacting with unsafe things like the Lua stack. +There is overhead associated with this safety. Unfortunately, `mlua` does not provide absolute safety even without using `unsafe` . This library contains a huge amount of unsafe code. There are almost certainly bugs still lurking in this library! @@ -216,8 +245,8 @@ It is surprisingly, fiendishly difficult to use the Lua C API without the potent ## Panic handling -`mlua` wraps panics that are generated inside Rust callbacks in a regular Lua error. Panics could be -resumed then by returning or propagating the Lua error to Rust code. +`mlua` wraps panics that are generated inside Rust callbacks in a regular Lua error. Panics can then be +resumed by returning or propagating the Lua error to Rust code. For example: ``` rust @@ -236,16 +265,16 @@ let _ = lua.load(r#" unreachable!() ``` -Optionally `mlua` can disable Rust panics catching in Lua via `pcall`/`xpcall` and automatically resume +Optionally, `mlua` can disable Rust panic catching in Lua via `pcall`/`xpcall` and automatically resume them across the Lua API boundary. This is controlled via `LuaOptions` and done by wrapping the Lua `pcall`/`xpcall` -functions on a way to prevent catching errors that are wrapped Rust panics. +functions to prevent catching errors that are wrapped Rust panics. `mlua` should also be panic safe in another way as well, which is that any `Lua` instances or handles -remains usable after a user generated panic, and such panics should not break internal invariants or +remain usable after a user generated panic, and such panics should not break internal invariants or leak Lua stack space. This is mostly important to safely use `mlua` types in Drop impls, as you should not be using panics for general error handling. -Below is a list of `mlua` behaviors that should be considered a bug. +Below is a list of `mlua` behaviors that should be considered bugs. If you encounter them, a bug report would be very welcome: + If you can cause UB with `mlua` without typing the word "unsafe", this is a bug. @@ -256,6 +285,14 @@ If you encounter them, a bug report would be very welcome: + If you detect that, after catching a panic or during a Drop triggered from a panic, a `Lua` or handle method is triggering other bugs or there is a Lua stack space leak, this is a bug. `mlua` instances are supposed to remain fully usable in the face of user generated panics. This guarantee does not extend to panics marked with "mlua internal error" simply because that is already indicative of a separate bug. +## Sandboxing + +Please check the [Luau Sandboxing] page if you are interested in running untrusted Lua scripts in a controlled environment. + +`mlua` provides the `Lua::sandbox` method for enabling sandbox mode (Luau only). + +[Luau Sandboxing]: https://luau.org/sandbox + ## License -This project is licensed under the [MIT license](LICENSE) +This project is licensed under the [MIT license](LICENSE). diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 346ee7a8..3e3bcb96 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -1,5 +1,7 @@ -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; + +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use tokio::runtime::Runtime; use tokio::task; @@ -10,10 +12,10 @@ fn collect_gc_twice(lua: &Lua) { lua.gc_collect().unwrap(); } -fn create_table(c: &mut Criterion) { +fn table_create_empty(c: &mut Criterion) { let lua = Lua::new(); - c.bench_function("create [table empty]", |b| { + c.bench_function("table [create empty]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { @@ -24,34 +26,101 @@ fn create_table(c: &mut Criterion) { }); } -fn create_array(c: &mut Criterion) { +fn table_create_array(c: &mut Criterion) { let lua = Lua::new(); - c.bench_function("create [array] 10", |b| { + c.bench_function("table [create array]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { - let table = lua.create_table().unwrap(); - for i in 1..=10 { - table.set(i, i).unwrap(); - } + lua.create_sequence_from(1..=10).unwrap(); }, BatchSize::SmallInput, ); }); } -fn create_string_table(c: &mut Criterion) { +fn table_create_hash(c: &mut Criterion) { let lua = Lua::new(); - c.bench_function("create [table string] 10", |b| { + c.bench_function("table [create hash]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { - let table = lua.create_table().unwrap(); - for &s in &["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] { - let s = lua.create_string(s).unwrap(); - table.set(s.clone(), s).unwrap(); + lua.create_table_from( + ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] + .into_iter() + .map(|s| (s, s)), + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }); +} + +fn table_get_set(c: &mut Criterion) { + let lua = Lua::new(); + + c.bench_function("table [get and set]", |b| { + b.iter_batched( + || { + collect_gc_twice(&lua); + lua.create_table().unwrap() + }, + |table| { + for (i, s) in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] + .into_iter() + .enumerate() + { + table.raw_set(s, i).unwrap(); + assert_eq!(table.raw_get::(s).unwrap(), i); + } + }, + BatchSize::SmallInput, + ); + }); +} + +fn table_traversal_pairs(c: &mut Criterion) { + let lua = Lua::new(); + + c.bench_function("table [traversal pairs]", |b| { + b.iter_batched( + || lua.globals(), + |globals| { + for kv in globals.pairs::() { + let (_k, _v) = kv.unwrap(); + } + }, + BatchSize::SmallInput, + ); + }); +} + +fn table_traversal_for_each(c: &mut Criterion) { + let lua = Lua::new(); + + c.bench_function("table [traversal for_each]", |b| { + b.iter_batched( + || lua.globals(), + |globals| globals.for_each::(|_k, _v| Ok(())), + BatchSize::SmallInput, + ); + }); +} + +fn table_traversal_sequence(c: &mut Criterion) { + let lua = Lua::new(); + + let table = lua.create_sequence_from(1..1000).unwrap(); + + c.bench_function("table [traversal sequence]", |b| { + b.iter_batched( + || table.clone(), + |table| { + for v in table.sequence_values::() { + let _i = v.unwrap(); } }, BatchSize::SmallInput, @@ -59,234 +128,309 @@ fn create_string_table(c: &mut Criterion) { }); } -fn create_function(c: &mut Criterion) { +fn table_ref_clone(c: &mut Criterion) { let lua = Lua::new(); - c.bench_function("create [function] 10", |b| { + let t = lua.create_table().unwrap(); + + c.bench_function("table [ref clone]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { - for i in 0..10 { - lua.create_function(move |_, ()| Ok(i)).unwrap(); - } + let _t2 = t.clone(); + }, + BatchSize::SmallInput, + ); + }); +} + +fn function_create(c: &mut Criterion) { + let lua = Lua::new(); + + c.bench_function("function [create Rust]", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + lua.create_function(|_, ()| Ok(123)).unwrap(); + }, + BatchSize::SmallInput, + ); + }); +} + +fn function_call_sum(c: &mut Criterion) { + let lua = Lua::new(); + + let sum = lua + .create_function(|_, (a, b, c): (i64, i64, i64)| Ok(a + b - c)) + .unwrap(); + + c.bench_function("function [call Rust sum]", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + assert_eq!(sum.call::((10, 20, 30)).unwrap(), 0); + }, + BatchSize::SmallInput, + ); + }); +} + +fn function_call_lua_sum(c: &mut Criterion) { + let lua = Lua::new(); + + let sum = lua + .load("function(a, b, c) return a + b - c end") + .eval::() + .unwrap(); + + c.bench_function("function [call Lua sum]", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + assert_eq!(sum.call::((10, 20, 30)).unwrap(), 0); }, BatchSize::SmallInput, ); }); } -fn call_lua_function(c: &mut Criterion) { +fn function_call_concat(c: &mut Criterion) { let lua = Lua::new(); - c.bench_function("call Lua function [sum] 3 10", |b| { - b.iter_batched_ref( + let concat = lua + .create_function(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))) + .unwrap(); + let i = AtomicUsize::new(0); + + c.bench_function("function [call Rust concat string]", |b| { + b.iter_batched( || { collect_gc_twice(&lua); - lua.load("function(a, b, c) return a + b + c end") - .eval::() - .unwrap() + i.fetch_add(1, Ordering::Relaxed) }, - |function| { - for i in 0..10 { - let _result: i64 = function.call((i, i + 1, i + 2)).unwrap(); - } + |i| { + assert_eq!(concat.call::(("num:", i)).unwrap(), format!("num:{i}")); }, BatchSize::SmallInput, ); }); } -fn call_sum_callback(c: &mut Criterion) { +fn function_call_lua_concat(c: &mut Criterion) { let lua = Lua::new(); - let callback = lua - .create_function(|_, (a, b, c): (i64, i64, i64)| Ok(a + b + c)) + + let concat = lua + .load("function(a, b) return a..b end") + .eval::() .unwrap(); - lua.globals().set("callback", callback).unwrap(); + let i = AtomicUsize::new(0); - c.bench_function("call Rust callback [sum] 3 10", |b| { - b.iter_batched_ref( + c.bench_function("function [call Lua concat string]", |b| { + b.iter_batched( || { collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do callback(i, i+1, i+2) end end") - .eval::() - .unwrap() + i.fetch_add(1, Ordering::Relaxed) }, - |function| { - function.call::<_, ()>(()).unwrap(); + |i| { + assert_eq!(concat.call::(("num:", i)).unwrap(), format!("num:{i}")); }, BatchSize::SmallInput, ); }); } -fn call_async_sum_callback(c: &mut Criterion) { - let options = LuaOptions::new().thread_cache_size(1024); +fn function_async_call_sum(c: &mut Criterion) { + let options = LuaOptions::new().thread_pool_size(1024); let lua = Lua::new_with(LuaStdLib::ALL_SAFE, options).unwrap(); - let callback = lua + + let sum = lua .create_async_function(|_, (a, b, c): (i64, i64, i64)| async move { task::yield_now().await; - Ok(a + b + c) + Ok(a + b - c) }) .unwrap(); - lua.globals().set("callback", callback).unwrap(); - c.bench_function("call async Rust callback [sum] 3 10", |b| { + c.bench_function("function [async call Rust sum]", |b| { let rt = Runtime::new().unwrap(); b.to_async(rt).iter_batched( - || { - collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do callback(i, i+1, i+2) end end") - .eval::() - .unwrap() - }, - |function| async move { - function.call_async::<_, ()>(()).await.unwrap(); + || collect_gc_twice(&lua), + |_| async { + assert_eq!(sum.call_async::((10, 20, 30)).await.unwrap(), 0); }, BatchSize::SmallInput, ); }); } -fn call_concat_callback(c: &mut Criterion) { +fn registry_value_create(c: &mut Criterion) { let lua = Lua::new(); - let callback = lua - .create_function(|_, (a, b): (LuaString, LuaString)| { - Ok(format!("{}{}", a.to_str()?, b.to_str()?)) - }) - .unwrap(); - lua.globals().set("callback", callback).unwrap(); + lua.gc_stop(); - c.bench_function("call Rust callback [concat string] 10", |b| { - b.iter_batched_ref( - || { - collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do callback('a', tostring(i)) end end") - .eval::() - .unwrap() - }, - |function| { - function.call::<_, ()>(()).unwrap(); - }, + c.bench_function("registry value [create]", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| lua.create_registry_value("hello").unwrap(), BatchSize::SmallInput, ); }); } -fn create_registry_values(c: &mut Criterion) { +fn registry_value_get(c: &mut Criterion) { let lua = Lua::new(); + lua.gc_stop(); + + let value = lua.create_registry_value("hello").unwrap(); - c.bench_function("create [registry value] 10", |b| { + c.bench_function("registry value [get]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { - for _ in 0..10 { - lua.create_registry_value(lua.pack(true).unwrap()).unwrap(); - } - lua.expire_registry_values(); + assert_eq!(lua.registry_value::(&value).unwrap(), "hello"); }, BatchSize::SmallInput, ); }); } -fn create_userdata(c: &mut Criterion) { - struct UserData(i64); +fn userdata_create(c: &mut Criterion) { + struct UserData(#[allow(unused)] i64); impl LuaUserData for UserData {} let lua = Lua::new(); - c.bench_function("create [table userdata] 10", |b| { + c.bench_function("userdata [create]", |b| { b.iter_batched( || collect_gc_twice(&lua), |_| { - let table: LuaTable = lua.create_table().unwrap(); - for i in 1..11 { - table.set(i, UserData(i)).unwrap(); - } + lua.create_userdata(UserData(123)).unwrap(); + }, + BatchSize::SmallInput, + ); + }); +} + +fn userdata_call_index(c: &mut Criterion) { + struct UserData(#[allow(unused)] i64); + impl LuaUserData for UserData { + fn add_methods>(methods: &mut M) { + methods.add_meta_method(LuaMetaMethod::Index, move |_, _, key: LuaString| Ok(key)); + } + } + + let lua = Lua::new(); + let ud = lua.create_userdata(UserData(123)).unwrap(); + let index = lua + .load("function(ud) return ud.test end") + .eval::() + .unwrap(); + + c.bench_function("userdata [call index]", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + assert_eq!(index.call::(&ud).unwrap(), "test"); }, BatchSize::SmallInput, ); }); } -fn call_userdata_index(c: &mut Criterion) { +fn userdata_call_method(c: &mut Criterion) { struct UserData(i64); impl LuaUserData for UserData { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_meta_method(LuaMetaMethod::Index, move |_, _, index: String| Ok(index)); + fn add_methods>(methods: &mut M) { + methods.add_method("add", |_, this, i: i64| Ok(this.0 + i)); } } let lua = Lua::new(); - lua.globals().set("userdata", UserData(10)).unwrap(); + let ud = lua.create_userdata(UserData(123)).unwrap(); + let method = lua + .load("function(ud, i) return ud:add(i) end") + .eval::() + .unwrap(); + let i = AtomicUsize::new(0); - c.bench_function("call [userdata index] 10", |b| { - b.iter_batched_ref( + c.bench_function("userdata [call method]", |b| { + b.iter_batched( || { collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do local v = userdata.test end end") - .eval::() - .unwrap() + i.fetch_add(1, Ordering::Relaxed) }, - |function| { - function.call::<_, ()>(()).unwrap(); + |i| { + assert_eq!(method.call::((&ud, i)).unwrap(), 123 + i); }, BatchSize::SmallInput, ); }); } -fn call_userdata_method(c: &mut Criterion) { - struct UserData(i64); +// A userdata method call that goes through an implicit `__index` function +fn userdata_call_method_complex(c: &mut Criterion) { + struct UserData(u64); impl LuaUserData for UserData { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("method", |_, this, ()| Ok(this.0)); + fn register(registry: &mut LuaUserDataRegistry) { + registry.add_field_method_get("val", |_, this| Ok(this.0)); + registry.add_method_mut("inc_by", |_, this, by: u64| { + this.0 += by; + Ok(this.0) + }); + + #[cfg(feature = "luau")] + registry.enable_namecall(); } } let lua = Lua::new(); - lua.globals().set("userdata", UserData(10)).unwrap(); + let ud = lua.create_userdata(UserData(0)).unwrap(); + let inc_by = lua + .load("function(ud, s) return ud:inc_by(s) end") + .eval::() + .unwrap(); - c.bench_function("call [userdata method] 10", |b| { - b.iter_batched_ref( + c.bench_function("userdata [call method complex]", |b| { + b.iter_batched( || { collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do userdata:method() end end") - .eval::() - .unwrap() }, - |function| { - function.call::<_, ()>(()).unwrap(); + |_| { + inc_by.call::<()>((&ud, 1)).unwrap(); }, BatchSize::SmallInput, ); }); } -fn call_async_userdata_method(c: &mut Criterion) { - #[derive(Clone, Copy)] +fn userdata_async_call_method(c: &mut Criterion) { struct UserData(i64); impl LuaUserData for UserData { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("method", |_, this, ()| async move { Ok(this.0) }); + fn add_methods>(methods: &mut M) { + methods.add_async_method("add", |_, this, i: i64| async move { + task::yield_now().await; + Ok(this.0 + i) + }); } } - let options = LuaOptions::new().thread_cache_size(1024); + let options = LuaOptions::new().thread_pool_size(1024); let lua = Lua::new_with(LuaStdLib::ALL_SAFE, options).unwrap(); - lua.globals().set("userdata", UserData(10)).unwrap(); + let ud = lua.create_userdata(UserData(123)).unwrap(); + let method = lua + .load("function(ud, i) return ud:add(i) end") + .eval::() + .unwrap(); + let i = AtomicUsize::new(0); - c.bench_function("call async [userdata method] 10", |b| { + c.bench_function("userdata [async call method] 10", |b| { let rt = Runtime::new().unwrap(); b.to_async(rt).iter_batched( || { collect_gc_twice(&lua); - lua.load("function() for i = 1,10 do userdata:method() end end") - .eval::() - .unwrap() + (method.clone(), ud.clone(), i.fetch_add(1, Ordering::Relaxed)) }, - |function| async move { - function.call_async::<_, ()>(()).await.unwrap(); + |(method, ud, i)| async move { + assert_eq!(method.call_async::((ud, i)).await.unwrap(), 123 + i); }, BatchSize::SmallInput, ); @@ -296,23 +440,34 @@ fn call_async_userdata_method(c: &mut Criterion) { criterion_group! { name = benches; config = Criterion::default() - .sample_size(300) + .sample_size(500) .measurement_time(Duration::from_secs(10)) .noise_threshold(0.02); targets = - create_table, - create_array, - create_string_table, - create_function, - call_lua_function, - call_sum_callback, - call_async_sum_callback, - call_concat_callback, - create_registry_values, - create_userdata, - call_userdata_index, - call_userdata_method, - call_async_userdata_method, + table_create_empty, + table_create_array, + table_create_hash, + table_get_set, + table_traversal_pairs, + table_traversal_for_each, + table_traversal_sequence, + table_ref_clone, + + function_create, + function_call_sum, + function_call_lua_sum, + function_call_concat, + function_call_lua_concat, + function_async_call_sum, + + registry_value_create, + registry_value_get, + + userdata_create, + userdata_call_index, + userdata_call_method, + userdata_call_method_complex, + userdata_async_call_method, } criterion_main!(benches); diff --git a/benches/serde.rs b/benches/serde.rs new file mode 100644 index 00000000..002061ac --- /dev/null +++ b/benches/serde.rs @@ -0,0 +1,90 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; + +use mlua::prelude::*; + +fn collect_gc_twice(lua: &Lua) { + lua.gc_collect().unwrap(); + lua.gc_collect().unwrap(); +} + +fn encode_json(c: &mut Criterion) { + let lua = Lua::new(); + + let encode = lua + .create_function(|_, t: LuaValue| Ok(serde_json::to_string(&t).unwrap())) + .unwrap(); + let table = lua + .load( + r#"{ + name = "Clark Kent", + address = { + city = "Smallville", + state = "Kansas", + country = "USA", + }, + age = 22, + parents = {"Jonathan Kent", "Martha Kent"}, + superman = true, + interests = {"flying", "saving the world", "kryptonite"}, + }"#, + ) + .eval::() + .unwrap(); + + c.bench_function("serialize json", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + encode.call::(&table).unwrap(); + }, + BatchSize::SmallInput, + ); + }); +} + +fn decode_json(c: &mut Criterion) { + let lua = Lua::new(); + + let decode = lua + .create_function(|lua, s: String| { + lua.to_value(&serde_json::from_str::(&s).unwrap()) + }) + .unwrap(); + let json = r#"{ + "name": "Clark Kent", + "address": { + "city": "Smallville", + "state": "Kansas", + "country": "USA" + }, + "age": 22, + "parents": ["Jonathan Kent", "Martha Kent"], + "superman": true, + "interests": ["flying", "saving the world", "kryptonite"] + }"#; + + c.bench_function("deserialize json", |b| { + b.iter_batched( + || collect_gc_twice(&lua), + |_| { + decode.call::(json).unwrap(); + }, + BatchSize::SmallInput, + ); + }); +} + +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(500) + .measurement_time(Duration::from_secs(10)) + .noise_threshold(0.02); + targets = + encode_json, + decode_json, +} + +criterion_main!(benches); diff --git a/build/find_dummy.rs b/build/find_dummy.rs deleted file mode 100644 index d8ad44b5..00000000 --- a/build/find_dummy.rs +++ /dev/null @@ -1,5 +0,0 @@ -use std::path::PathBuf; - -pub fn probe_lua() -> Option { - None -} diff --git a/build/find_normal.rs b/build/find_normal.rs deleted file mode 100644 index 3e72609d..00000000 --- a/build/find_normal.rs +++ /dev/null @@ -1,93 +0,0 @@ -#![allow(dead_code)] - -use std::env; -use std::ops::Bound; -use std::path::PathBuf; - -fn get_env_var(name: &str) -> String { - match env::var(name) { - Ok(val) => val, - Err(env::VarError::NotPresent) => String::new(), - Err(err) => panic!("cannot get {}: {}", name, err), - } -} - -pub fn probe_lua() -> Option { - let include_dir = get_env_var("LUA_INC"); - let lib_dir = get_env_var("LUA_LIB"); - let lua_lib = get_env_var("LUA_LIB_NAME"); - - println!("cargo:rerun-if-env-changed=LUA_INC"); - println!("cargo:rerun-if-env-changed=LUA_LIB"); - println!("cargo:rerun-if-env-changed=LUA_LIB_NAME"); - println!("cargo:rerun-if-env-changed=LUA_LINK"); - - let need_lua_lib = cfg!(any(not(feature = "module"), target_os = "windows")); - - if !include_dir.is_empty() { - if need_lua_lib { - if lib_dir.is_empty() { - panic!("LUA_LIB is not set"); - } - if lua_lib.is_empty() { - panic!("LUA_LIB_NAME is not set"); - } - - let mut link_lib = ""; - if get_env_var("LUA_LINK") == "static" { - link_lib = "static="; - }; - println!("cargo:rustc-link-search=native={}", lib_dir); - println!("cargo:rustc-link-lib={}{}", link_lib, lua_lib); - } - return Some(PathBuf::from(include_dir)); - } - - // Find using `pkg-config` - - #[cfg(feature = "lua54")] - let (incl_bound, excl_bound, alt_probe, ver) = ("5.4", "5.5", "lua5.4", "5.4"); - #[cfg(feature = "lua53")] - let (incl_bound, excl_bound, alt_probe, ver) = ("5.3", "5.4", "lua5.3", "5.3"); - #[cfg(feature = "lua52")] - let (incl_bound, excl_bound, alt_probe, ver) = ("5.2", "5.3", "lua5.2", "5.2"); - #[cfg(feature = "lua51")] - let (incl_bound, excl_bound, alt_probe, ver) = ("5.1", "5.2", "lua5.1", "5.1"); - - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "lua51" - ))] - { - let mut lua = pkg_config::Config::new() - .range_version((Bound::Included(incl_bound), Bound::Excluded(excl_bound))) - .cargo_metadata(need_lua_lib) - .probe("lua"); - - if lua.is_err() { - lua = pkg_config::Config::new() - .cargo_metadata(need_lua_lib) - .probe(alt_probe); - } - - lua.unwrap_or_else(|_| panic!("cannot find Lua {} using `pkg-config`", ver)) - .include_paths - .get(0) - .cloned() - } - - #[cfg(feature = "luajit")] - { - let lua = pkg_config::Config::new() - .range_version((Bound::Included("2.0.4"), Bound::Unbounded)) - .cargo_metadata(need_lua_lib) - .probe("luajit"); - - lua.expect("cannot find LuaJIT using `pkg-config`") - .include_paths - .get(0) - .cloned() - } -} diff --git a/build/main.rs b/build/main.rs deleted file mode 100644 index 5cb979a9..00000000 --- a/build/main.rs +++ /dev/null @@ -1,115 +0,0 @@ -#[cfg_attr( - any( - feature = "luau", - all( - feature = "vendored", - any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit" - ) - ) - ), - path = "find_vendored.rs" -)] -#[cfg_attr( - all( - not(feature = "vendored"), - any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit" - ) - ), - path = "find_normal.rs" -)] -#[cfg_attr( - not(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit", - feature = "luau" - )), - path = "find_dummy.rs" -)] -mod find; - -fn main() { - #[cfg(not(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit", - feature = "luau" - )))] - compile_error!( - "You must enable one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - #[cfg(all( - feature = "lua54", - any( - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit", - feature = "luau" - ) - ))] - compile_error!( - "You can enable only one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - #[cfg(all( - feature = "lua53", - any( - feature = "lua52", - feature = "lua51", - feature = "luajit", - feature = "luau" - ) - ))] - compile_error!( - "You can enable only one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - #[cfg(all( - feature = "lua52", - any(feature = "lua51", feature = "luajit", feature = "luau") - ))] - compile_error!( - "You can enable only one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - #[cfg(all(feature = "lua51", any(feature = "luajit", feature = "luau")))] - compile_error!( - "You can enable only one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - #[cfg(all(feature = "luajit", feature = "luau"))] - compile_error!( - "You can enable only one of the features: lua54, lua53, lua52, lua51, luajit, luajit52, luau" - ); - - // We don't support "vendored module" mode on windows - #[cfg(all(feature = "vendored", feature = "module", target_os = "windows"))] - compile_error!( - "Vendored (static) builds are not supported for modules on Windows.\n" - + "Please, use `pkg-config` or custom mode to link to a Lua dll." - ); - - #[cfg(all(feature = "luau", feature = "module"))] - compile_error!("Luau does not support module mode"); - - #[cfg(any(not(feature = "module"), target_os = "windows"))] - find::probe_lua(); - - println!("cargo:rerun-if-changed=build"); -} diff --git a/docs/release_notes/v0.10.md b/docs/release_notes/v0.10.md new file mode 100644 index 00000000..2684a4f3 --- /dev/null +++ b/docs/release_notes/v0.10.md @@ -0,0 +1,195 @@ +## mlua v0.10 release notes + +The v0.10 version of mlua has a goal to improve the user experience while keeping the same performance and safety guarantees. +This document highlights the most notable features. For a full list of changes, see the [CHANGELOG]. + +[CHANGELOG]: https://github.com/mlua-rs/mlua/blob/main/CHANGELOG.md + +### New features + +#### `'static` Lua types + +In previous mlua versions, it was required to have a `'lua` lifetime attached to every Lua value. v0.9 introduced (experimental) owned types that are `'static` without a lifetime attached, but they kept strong references to the Lua instance. +In v0.10 all Lua types are `'static` and have only weak reference to the Lua instance. It means they are more flexible and can be used in more places without worrying about memory leaks. + +#### Truly `send` feature + +In this version Lua is `Send + Sync` when the `send` feature flag is enabled (previously was only `Send`). It means Lua instance and their values can be safely shared between threads and used in multi threaded async contexts. + +```rust +let lua = Lua::new(); + +lua.globals().set("i", 0)?; +let func = lua.load("i = i + ...").into_function()?; + +std::thread::scope(|s| { + s.spawn(|| { + for i in 0..5 { + func.call::<()>(i).unwrap(); + } + }); + s.spawn(|| { + for i in 0..5 { + func.call::<()>(i).unwrap(); + } + }); +}); + +assert_eq!(lua.globals().get::("i")?, 20); +``` + +Under the hood, to synchronize access to the Lua state, mlua uses [`ReentrantMutex`] which can be recursively locked by a single thread. Only one thread can execute Lua code at a time, but it's possible to share Lua values between threads. + +This has some performance penalties (about 10-20%) compared to the lock free mode. This flag is disabled by default and is not supported in module mode. + +[`ReentrantMutex`]: https://docs.rs/parking_lot/latest/parking_lot/type.ReentrantMutex.html + +#### Register Rust functions with variable number of arguments + +The new traits `LuaNativeFn`/`LuaNativeFnMut`/`LuaNativeAsyncFn` have been introduced to provide a way to register Rust functions with variable number of arguments in Lua, without needing to pass all arguments as a tuple. + +They are used by `Function::wrap`/`Function::wrap_mut`/`Function::wrap_async` methods: + +```rust +let add = Function::wrap(|a: i64, b: i64| Ok(a + b)); + +lua.globals().set("add", add).unwrap(); + +// Prints 50 +lua.load(r#"print(add(5, 45))"#).exec().unwrap(); +``` + +To wrap functions that return direct value (non-`Result`) you can use `Function::wrap_raw` method. + +#### Setting metatable for Lua builtin types + +For Lua builtin types (like `string`, `function`, `number`, etc.) that have a shared metatable for all instances, it's now possible to set a custom metatable for them. + +```rust +let mt = lua.create_table()?; +mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { "2" } else { "0" }))?)?; +lua.set_type_metatable::(Some(mt)); +lua.load("assert(tostring(true) == '2')").exec().unwrap(); +``` + +### Improvements + +#### New `ObjectLike` trait + +The `ObjectLike` trait is a combination of the `AnyUserDataExt` and `TableExt` traits used in previous versions. It provides a unified interface for working with Lua tables and userdata. + +#### `Either` enum + +The `Either` enum is a simple enum that can hold either `L` or `R` value. It's useful when you need to return or receive one of two types in a function. +This type implements `IntoLua` and `FromLua` traits and can generate a meaningful error message when conversion fails. + +```rust +let func = Function::wrap(|x: Either| Ok(format!("received: {x}"))); + +lua.globals().set("func", func).unwrap(); + +// Prints: received: 123 +lua.load(r#"print(func(123))"#).exec().unwrap(); + +// Prints: bad argument #1: error converting Lua table to Either +lua.load(r#"print(pcall(func, {}))"#).exec().unwrap(); +``` + +#### `Lua::exec_raw` helper to execute low-level Lua C API code + +For advanced users, it's now possible to execute low-level Lua C API code using the `Lua::exec_raw` method. + +```rust +let t = lua.create_sequence_from([1, 2, 3, 4, 5])?; +let sum: i64 = unsafe { + lua.exec_raw(&t, |state| { + // top of the stack: table `t` + let mut sum = 0; + // push nil as the first key + mlua::ffi::lua_pushnil(state); + while mlua::ffi::lua_next(state, -2) != 0 { + sum += mlua::ffi::lua_tointeger(state, -1); + // Remove the value, keep the key for the next iteration + mlua::ffi::lua_pop(state, 1); + } + mlua::ffi::lua_pop(state, 1); + mlua::ffi::lua_pushinteger(state, sum); + // top of the stack: sum + }) +}?; +assert_eq!(sum, 15); +``` + +The `exec_raw` method is longjmp-safe. It's not recommended to move `Drop` types into the closure to avoid possible memory leaks. + +#### `anyhow` feature flag + +The new `anyhow` feature flag adds `IntoLua` and `Into` implementation for the `anyhow::Error` type. + +```rust +let f = lua.create_function(|_, ()| { + Err(anyhow!("error message"))?; + Ok(()) +})?; +``` + +### Breaking changes + +#### Scope changes + +The following `Scope` methods were changed: +- Removed `Scope::create_any_userdata` +- `Scope::create_nonstatic_userdata` is renamed to `Scope::create_userdata` + +Instead, scope has comprehensive support for borrowed userdata: `create_any_userdata_ref`, `create_any_userdata_ref_mut`, `create_userdata_ref`, `create_userdata_ref_mut`. + +`UserDataRef` and `UserDataRefMut` are no longer acceptable for scoped userdata access as they require owned underlying data. +In mlua v0.9 this could cause a read-after-free bug in some edge cases. + +To temporarily borrow underlying data, the `AnyUserData::borrow_scoped` and `AnyUserData::borrow_mut_scoped` methods were introduced: + +```rust +let data = "hello".to_string(); +lua.scope(|scope| { + let ud = scope.create_any_userdata_ref(&data)?; + + // We can only borrow scoped userdata using this method + ud.borrow_scoped::(|s| { + assert_eq!(s, "hello"); + })?; + + Ok(()) +})?; +``` + +Those methods work for scoped and regular userdata objects (but still require `T: 'static`). + +#### String changes + +Since `mlua::String` holds a weak reference to Lua without any guarantees about the lifetime of the underlying data, getting a `&str` or `&[u8]` from it is no longer safe. +Lua instance can be destroyed while reference to the data is still alive: + +```rust +let lua = Lua::new(); +let s: mlua::String = lua.create_string("hello, world")?; // only weak reference to Lua! +let s_ref: &str = s.to_str()?; // this is not safe! +drop(lua); +println!("{s_ref}"); // use after free! +``` + +To solve this issue, return types of `mlua::String::to_str` and `mlua::String::as_bytes` methods changed to `BorrowedStr` and `BorrowedBytes` respectively. + +These new types hold a strong reference to the Lua instance and can be safely converted to `&str` or `&[u8]`: + +```rust +let lua = Lua::new(); +let s: mlua::String = lua.create_string("hello, world")?; +let s_ref: mlua::BorrowedStr = s.to_str()?; // The strong reference to Lua is held here +drop(lua); +println!("{s_ref}"); // ok +``` + +The good news is that `BorrowedStr` implements `Deref`/`AsRef` as well as `Display`, `Debug`, `Eq`, `PartialEq` and other traits for easy usage. +The same applies to `BorrowedBytes`. + +Unfortunately, `mlua::String::to_string_lossy` cannot return `Cow<'a, str>` anymore, because it requires a strong reference to Lua. It now returns Rust `String` instead. diff --git a/docs/release_notes/v0.9.md b/docs/release_notes/v0.9.md new file mode 100644 index 00000000..9de0ce2d --- /dev/null +++ b/docs/release_notes/v0.9.md @@ -0,0 +1,361 @@ +## mlua v0.9 release notes + +The v0.9 version of mlua is a major release that includes a number of API changes and improvements. This release is a stepping stone towards the v1.0. +This document highlights the most important changes. For a full list of changes, see the [CHANGELOG]. + +[CHANGELOG]: https://github.com/mlua-rs/mlua/blob/main/CHANGELOG.md + +### New features + +#### 1. New Any UserData API + +This is a long awaited feature that allows to register in Lua foreign types that cannot implement `UserData` trait because of the Rust orphan rules. + +Now you can register any type that implements [`Any`] trait as a userdata type. + +Consider the following example: + +```rust +lua.register_userdata_type::(|reg| { + reg.add_method("len", |_, this, ()| Ok(this.len())); + + reg.add_method_mut("push", |_, this, s: String| { + this.push_str(&s); + Ok(()) + }); + + reg.add_meta_method(MetaMethod::ToString, |lua, this, ()| lua.create_string(this)); +})?; + +let s = lua.create_any_userdata("hello".to_string())?; +lua.load(chunk! { + print("s:len() is " .. $s:len()) + $s:push(" world") + // Prints: hello, world + print($s) +}) +.exec()?; +``` + +In this example we registered [`std::string::String`] as a userdata type with a set of methods and then created an instance of this type in Lua. + +It's _not_ required to register a type before using the `Lua::create_any_userdata()` method, instead an empty metatable will be created for you. +You can also register the same type multiple times with different methods. Any previously created instances will share the old metatable, while new instances will have the new one. + +The new set of API is called `any_userdata` because it allows to register types that implements [`Any`] trait. + +[`std::string::String`]: https://doc.rust-lang.org/stable/std/string/struct.String.html +[`Any`]: https://doc.rust-lang.org/stable/std/any/trait.Any.html + +#### 2. Scope support for the new any userdata types + +When you need to create non-static userdata instances in Lua, the usual way is use `Lua::scope()` helper to make them scoped. When out of scope, any scoped objects will be automatically +dropped. The only downside of this approach is that every new instance will have a new metatable. This is not very fast if you need to create a lot of instances. + +With the new Any UserData API, you can place non-static references `&T` where `T: 'static` into a scope and they will share a single static metatable. + +```rust +lua.register_userdata_type::(|reg| { + reg.add_method_mut("replace", |_, this, (pat, to): (String, String)| { + *this = this.replace(&pat, &to); + Ok(()) + }); + + reg.add_meta_method(MetaMethod::ToString, |lua, this, ()| lua.create_string(this)); +})?; + +let mut s = "hello, world".to_string(); + +lua.scope(|scope| { + // This userdata instance holds only a mutable reference to our string + let ud = scope.create_any_userdata_ref_mut(&mut s)?; + lua.load(chunk! { + $ud:replace("world", "user") + }) + .exec() +})?; + +// Prints: hello, user! +println!("{s}!"); +``` + +#### 3. Owned types (`unstable`) + +One of the common questions was how to embed a Lua type into Rust struct to use it later. It was non-trivial to do because of the `'lua` lifetime attached to every Lua value. + +In v0.9 mlua introduces "owned" types `OwnedTable`/`OwnedFunction`/`OwnedString`/`OwnedAnyUserData`/ `OwnedThread`that are `'static` (no lifetime attached). + +```rust +let lua = Lua::new(); + +struct MyStruct { + table: OwnedTable, + func: OwnedFunction, +} + +let my_struct = MyStruct { + table: lua.globals().into_owned(), + func: lua + .create_function(|_, t: Table| Ok(format!("{t:#?}")))? + .into_owned(), +}; + +// It's safe to drop Lua! +drop(lua); + +let result = my_struct.func.call::<_, String>(my_struct.table)?; +println!("{result}"); +``` + +Prior to v0.9, it was possible to do by creating a reference to the Lua value in registry using `Lua::create_registry_value()` +and retrieving value later using `Lua::registry_value()` method. + +All owned handles hold a *strong* reference to the current Lua instance. +Be warned, if you place them into a Lua type (eg. `UserData` or a Rust callback), it is *very easy* +to accidentally cause reference cycles that would prevent destroying Lua instance. + +Please note this functionality is available under the `unstable` feature flag and not available when the `send` feature is enabled. + +#### New ffi module + +In v0.9 release the internal `ffi` module has been moved into the new [`mlua-sys`] crate and became available for public use. +This crate provides unified Lua FFI API (targeting Lua 5.4) using a (limited) compatibility layer for older versions. + +mlua re-exports the `ffi` module aliasing the `mlua-sys` crate and provides (unsafe) functionality to work with raw Lua state: + +```rust +unsafe { + unsafe extern "C-unwind" fn lua_add(state: *mut mlua::lua_State) -> i32 { + let a = mlua::ffi::luaL_checkinteger(state, 1); + let b = mlua::ffi::luaL_checkinteger(state, 2); + mlua::ffi::lua_pushinteger(state, a + b); + 1 + } + + let add = lua.create_c_function(lua_add)?; + assert_eq!(add.call::<_, i32>((2, 3))?, 5); +} +``` + +[`mlua-sys`]: https://crates.io/crates/mlua-sys + +#### Luau JIT support + +mlua brings support for the new [Luau] JIT backend under the `luau-jit` feature flag. + +It will automatically trigger JIT compilation for new Lua chunks. To disable it, just call `lua.enable_jit(false)` before loading Lua code +(but any previously compiled chunks will remain JIT-compiled). + +[Luau]: https://luau-lang.org + +### Improvements + +#### 1. Better error reporting + +When calling a Rust function from Lua and passing wrong arguments, previous mlua versions reported an error message without any context or reference to the particular argument. + +In v0.9 it reports an error message with the argument index and expected type: + +```rust +let func = lua.create_function(|_, _a: i32| Ok(()))?; +lua.load(chunk! { + local ok, err = pcall($func, "not a number") + // Prints: bad argument #1: error converting Lua string to i32 (expected number or string coercible to number) + print(err) +}) +.exec()?; +``` + +Similar changes have been made for userdata functions and methods: + +```rust +lua.register_userdata_type::<&'static str>(|reg| { + reg.add_method("len", |_, this, ()| Ok(this.len())); +})?; + +let s = lua.create_any_userdata("hello")?; +lua.load(chunk! { + local ok, err = pcall($s.len, 123) + // Prints: bad argument `self` to `&str.len`: error converting Lua integer to userdata + print(err) +}) +.exec()?; +``` + +#### 2. Error context + +Similar to the [`anyhow`] Error type, now it's possible to attach context to Lua errors: + +```rust +let read = lua.create_function(|lua, path: String| { + let bytes = std::fs::read(&path) + .into_lua_err() + .context(format!("Failed to open `{path}`"))?; + Ok(lua.create_string(bytes)) +})?; + +lua.load(chunk! { + local ok, err = pcall($read, "/nonexistent") + /// Prints: + /// Failed to open /nonexistent + /// No such file or directory (os error 2) + /// stack traceback: + /// ... + print(err) +}) +.exec()?; +``` + +[`anyhow`]: https://crates.io/crates/anyhow + +#### 4. New methods `Function::wrap`/`AnyUserData::wrap` + +Sometimes it's useful to have `IntoLua` trait implementation for a Rust function or type `T: Any` without needing to call `Lua::create_function()`/`Lua::create_any_userdata()` methods. +Since v0.9 you can call the new methods `Function::wrap()`/`AnyUserData::wrap()` that allows to do this. They return an abstract type that `impl IntoLua`: + +```rust +lua.globals().set("print_rust", Function::wrap(|_, s: String| Ok(println!("{}", s))))?; +lua.globals().set("rust_ud", AnyUserData::wrap("hello"))?; +``` + +In addition there are also `Function::wrap_mut()`/`Function::wrap_async()` methods that allow to wrap mutable and async functions respectively. + +For a `T: 'UserData + 'static` the `IntoLua` trait is still always implemented. + +#### `UserDataRef` and `UserDataRefMut` type wrappers + +The new wrappers `UserDataRef` and `UserDataRefMut` are receivers for userdata type `T` and borrow underlying instance for the lifetime of the wrapper. + +```rust +lua.globals() + .set("ud", AnyUserData::wrap("hello".to_string()))?; + +let mut ud_mut: UserDataRefMut = lua.globals().get("ud")?; +ud_mut.push_str(", Rust"); +drop(ud_mut); + +let ud_ref: UserDataRef = lua.globals().get("ud")?; +// Prints: hello, Rust +println!("{}", *ud_ref); +``` + +In the previous mlua versions the same functionality can be achieved by receiving `AnyUserData` and calling `AnyUserData::borrow()`/`AnyUserData::borrow_mut()` methods. + +The new wrappers are identical to Rust [`Ref`]/[`RefMut`] types. + +[`Ref`]: https://doc.rust-lang.org/std/cell/struct.Ref.html +[`RefMut`]: https://doc.rust-lang.org/std/cell/struct.RefMut.html + +#### New `AnyUserDataExt` trait + +Similar to the `TableExt` trait, the `AnyUserDataExt` provides a set of extra methods for the `AnyUserData` type. + +1) `AnyUserDataExt::get()/set()` to get/set a value by key from the userdata, assuming it has `__index` metamethod. + +2) `AnyUserDataExt::call()` to call the userdata as a function assuming it has `__call` metamethod. + +3) `AnyUserData::call_method(name, ...)` to call the userdata method, assuming it has `__index` metamethod and the associated function. + +#### Pretty formatting Lua values + +`mlua::Value` implements a new format `:#?` that allows to (recursively) pretty print Lua values: + +```rust +println!("{:#?}", lua.globals()); +``` + +Prints: +``` +{ + ["_G"] = table: 0x7fa2d0706260, + ["_VERSION"] = "Lua 5.4", + ["assert"] = function: 0x10451d11d, + ["collectgarbage"] = function: 0x10451d198, + ["coroutine"] = { + ["close"] = function: 0x10451e28f, + ... + }, + ["dofile"] = function: 0x10451d37c, + ... +} +``` + +In addition a new method `Value::to_string()` has been added to convert `Value` to a string (using `__tostring` metamethod if available). + +#### Environment for Lua functions + +Any Lua functions have an associated environment table that is used to resolve global variables. By default it sets to a Lua globals table. + +In the new release it's possible to get or update a function environment using `Function::environment()` or `Function::set_environment()` methods respectively. + +```rust +let f = lua.load("return a").into_function()?; + +assert_eq!(f.environment(), Some(lua.globals())); + +lua.globals().set("a", 1)?; +assert_eq!(f.call::<_, i32>(())?, 1); + +f.set_environment(lua.create_table_from([("a", "hello")])?)?; +assert_eq!(f.call::<_, mlua::String>(())?, "hello"); +``` + +#### Performance optimizations + +The new mlua version has a number of performance improvements. Please check the [benchmarks results] to see how mlua compares to rlua and rhai. + +[benchmarks results]: https://github.com/mlua-rs/script-bench-rs + +### Changes in `module` mode + +#### New attributes + +The `lua_module` macro now support the following attributes: + +- `name=...` - sets name of the module (defaults to the name of the function). + +Eg.: + +```rust +#[mlua::lua_module(name = "alt_module")] +fn my_module(lua: &Lua) -> LuaResult { + lua.create_table() +} +``` + +Under the hood a new function `luaopen_alt_module` will be created for the Lua module loader. + +- `skip_memory_check` - skip memory allocation checks for some operations. + +In module mode, mlua runs in an unknown environment and cannot tell whether there are any memory limits or not. As a result, some operations that require memory allocation run in +protected mode. Setting this attribute will improve performance of such operations with risk of having uncaught exceptions and memory leaks. + +#### Improved Windows target + +In previous mlua versions, building a Lua module for Windows requires having Lua development libraries installed on the system. +In contrast, on Linux and macOS, modules can be built without any external dependencies using the `-undefined=dynamic_lookup` linker flag. + +With Rust 1.71+ it's now possible to lift this restriction for Windows as well. You can build modules normally and they will be linked with +`lua5x.dll` depending on the enabled Lua version. + +You still need to have the dll although, linked to application where the module will be loaded. + +### Breaking changes + +1) `ToLua`/`ToLuaMulti` traits have been renamed to `IntoLua`/`IntoLuaMulti` respectively (with the methods called `into_lua`/`into_lua_multi`). + +The main reason for this change is following the Rust self [convention](https://rust-lang.github.io/rust-clippy/master/index.html#/wrong_self_convention). + +2) Removed `FromLua` implementation for `T: UserData + Clone`. + +During the usage of mlua, it was found that this implementation is not very useful and prevents custom `FromLua` implementations for `T: UserData`. +It should be a developer decision to opt-in `FromLua` for their `T` if needed rather than having enabled it unconditionally. + +To opt-in `FromLua` for `T: Clone` you can use a simple `#[derive(FromLua)]` macro (requires `feature = "macros"`): + +```rust +#[derive(Clone, Copy, mlua::FromLua)] +struct MyUserData(i32); +``` + +`T` is not required to implement `UserData` because of the new relaxed restrictions on userdata types. diff --git a/examples/async_http_client.rs b/examples/async_http_client.rs index e7918082..a2d9ed83 100644 --- a/examples/async_http_client.rs +++ b/examples/async_http_client.rs @@ -1,33 +1,36 @@ use std::collections::HashMap; -use hyper::body::{Body as HyperBody, HttpBody as _}; -use hyper::Client as HyperClient; +use http_body_util::BodyExt as _; +use hyper::body::Incoming; +use hyper_util::client::legacy::Client as HyperClient; +use hyper_util::rt::TokioExecutor; -use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods}; +use mlua::{ExternalResult, Lua, Result, UserData, UserDataMethods, chunk}; -struct BodyReader(HyperBody); +struct BodyReader(Incoming); impl UserData for BodyReader { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_function("read", |lua, reader: AnyUserData| async move { - let mut reader = reader.borrow_mut::()?; - if let Some(bytes) = reader.0.data().await { - let bytes = bytes.to_lua_err()?; - return Some(lua.create_string(&bytes)).transpose(); + fn add_methods>(methods: &mut M) { + // Every call returns a next chunk + methods.add_async_method_mut("read", |lua, mut reader, ()| async move { + if let Some(bytes) = reader.0.frame().await { + if let Some(bytes) = bytes.into_lua_err()?.data_ref() { + return Some(lua.create_string(&bytes)).transpose(); + } } Ok(None) }); } } -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() -> Result<()> { let lua = Lua::new(); let fetch_url = lua.create_async_function(|lua, uri: String| async move { - let client = HyperClient::new(); - let uri = uri.parse().to_lua_err()?; - let resp = client.get(uri).await.to_lua_err()?; + let client = HyperClient::builder(TokioExecutor::new()).build_http::(); + let uri = uri.parse().into_lua_err()?; + let resp = client.get(uri).await.into_lua_err()?; let lua_resp = lua.create_table()?; lua_resp.set("status", resp.status().as_u16())?; @@ -37,7 +40,7 @@ async fn main() -> Result<()> { headers .entry(key.as_str()) .or_insert(Vec::new()) - .push(value.to_str().to_lua_err()?); + .push(value.to_str().into_lua_err()?); } lua_resp.set("headers", headers)?; @@ -56,11 +59,11 @@ async fn main() -> Result<()> { end end repeat - local body = res.body:read() - if body then - print(body) + local chunk = res.body:read() + if chunk then + print(chunk) end - until not body + until not chunk }) .into_function()?; diff --git a/examples/async_http_reqwest.rs b/examples/async_http_reqwest.rs index 78c425bb..91206d4a 100644 --- a/examples/async_http_reqwest.rs +++ b/examples/async_http_reqwest.rs @@ -1,33 +1,27 @@ -use mlua::{chunk, ExternalResult, Lua, LuaSerdeExt, Result}; +use mlua::{ExternalResult, Lua, LuaSerdeExt, Result, Value, chunk}; -#[tokio::main] +#[tokio::main(flavor = "current_thread")] async fn main() -> Result<()> { let lua = Lua::new(); - let null = lua.null(); - let fetch_json = lua.create_async_function(|lua, uri: String| async move { let resp = reqwest::get(&uri) .await .and_then(|resp| resp.error_for_status()) - .to_lua_err()?; - let json = resp.json::().await.to_lua_err()?; + .into_lua_err()?; + let json = resp.json::().await.into_lua_err()?; lua.to_value(&json) })?; + let dbg = lua.create_function(|_, value: Value| { + println!("{value:#?}"); + Ok(()) + })?; + let f = lua .load(chunk! { - function print_r(t, indent) - local indent = indent or "" - for k, v in pairs(t) do - io.write(indent, tostring(k)) - if type(v) == "table" then io.write(":\n") print_r(v, indent.." ") - else io.write(": ", v == $null and "null" or tostring(v), "\n") end - end - end - local res = $fetch_json(...) - print_r(res) + $dbg(res) }) .into_function()?; diff --git a/examples/async_http_server.rs b/examples/async_http_server.rs index 43ae7a95..f5057ed6 100644 --- a/examples/async_http_server.rs +++ b/examples/async_http_server.rs @@ -1,60 +1,70 @@ +use std::convert::Infallible; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use hyper::server::conn::AddrStream; -use hyper::service::Service; -use hyper::{Body, Request, Response, Server}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt as _, Empty, Full}; +use hyper::body::{Bytes, Incoming}; +use hyper::server::conn::http1; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; -use mlua::{ - chunk, Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods, -}; +use mlua::{Error as LuaError, Function, Lua, String as LuaString, Table, UserData, UserDataMethods, chunk}; -struct LuaRequest(SocketAddr, Request); +/// Wrapper around incoming request that implements UserData +struct LuaRequest(SocketAddr, Request); impl UserData for LuaRequest { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("remote_addr", |_lua, req, ()| Ok((req.0).to_string())); - methods.add_method("method", |_lua, req, ()| Ok((req.1).method().to_string())); + fn add_methods>(methods: &mut M) { + methods.add_method("remote_addr", |_, req, ()| Ok((req.0).to_string())); + methods.add_method("method", |_, req, ()| Ok((req.1).method().to_string())); + methods.add_method("path", |_, req, ()| Ok(req.1.uri().path().to_string())); } } -pub struct Svc(Rc, SocketAddr); - -impl Service> for Svc { - type Response = Response; - type Error = LuaError; - type Future = Pin>>>; +/// Service that handles incoming requests +#[derive(Clone)] +pub struct Svc { + handler: Function, + peer_addr: SocketAddr, +} - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) +impl Svc { + pub fn new(handler: Function, peer_addr: SocketAddr) -> Self { + Self { handler, peer_addr } } +} + +impl hyper::service::Service> for Svc { + type Response = Response>; + type Error = LuaError; + type Future = Pin> + Send>>; - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { // If handler returns an error then generate 5xx response - let lua = self.0.clone(); - let lua_req = LuaRequest(self.1, req); + let handler = self.handler.clone(); + let lua_req = LuaRequest(self.peer_addr, req); Box::pin(async move { - let handler: Function = lua.named_registry_value("http_handler")?; - match handler.call_async::<_, Table>(lua_req).await { + match handler.call_async::(lua_req).await { Ok(lua_resp) => { - let status = lua_resp.get::<_, Option>("status")?.unwrap_or(200); + let status = lua_resp.get::>("status")?.unwrap_or(200); let mut resp = Response::builder().status(status); // Set headers - if let Some(headers) = lua_resp.get::<_, Option
>("headers")? { + if let Some(headers) = lua_resp.get::>("headers")? { for pair in headers.pairs::() { let (h, v) = pair?; - resp = resp.header(&h, v.as_bytes()); + resp = resp.header(&h, &*v.as_bytes()); } } + // Set body let body = lua_resp - .get::<_, Option>("body")? - .map(|b| Body::from(b.as_bytes().to_vec())) - .unwrap_or_else(Body::empty); + .get::>("body")? + .map(|b| Full::new(Bytes::copy_from_slice(&b.as_bytes())).boxed()) + .unwrap_or_else(|| Empty::::new().boxed()); Ok(resp.body(body).unwrap()) } @@ -62,7 +72,7 @@ impl Service> for Svc { eprintln!("{}", err); Ok(Response::builder() .status(500) - .body(Body::from("Internal Server Error")) + .body(Full::new(Bytes::from("Internal Server Error")).boxed()) .unwrap()) } } @@ -72,65 +82,47 @@ impl Service> for Svc { #[tokio::main(flavor = "current_thread")] async fn main() { - let lua = Rc::new(Lua::new()); + let lua = Lua::new(); // Create Lua handler function - let handler: Function = lua + let handler = lua .load(chunk! { function(req) return { status = 200, headers = { ["X-Req-Method"] = req:method(), + ["X-Req-Path"] = req:path(), ["X-Remote-Addr"] = req:remote_addr(), }, body = "Hello from Lua!\n" } end }) - .eval() - .expect("cannot create Lua handler"); - - // Store it in the Registry - lua.set_named_registry_value("http_handler", handler) - .expect("cannot store Lua handler"); - - let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).executor(LocalExec).serve(MakeSvc(lua)); - - println!("Listening on http://{}", addr); - - // Create `LocalSet` to spawn !Send futures - let local = tokio::task::LocalSet::new(); - local.run_until(server).await.expect("cannot run server") -} - -struct MakeSvc(Rc); - -impl Service<&AddrStream> for MakeSvc { - type Response = Svc; - type Error = hyper::Error; - type Future = Pin>>>; - - fn poll_ready(&mut self, _: &mut Context) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, stream: &AddrStream) -> Self::Future { - let lua = self.0.clone(); - let remote_addr = stream.remote_addr(); - Box::pin(async move { Ok(Svc(lua, remote_addr)) }) - } -} - -#[derive(Clone, Copy, Debug)] -struct LocalExec; - -impl hyper::rt::Executor for LocalExec -where - F: std::future::Future + 'static, // not requiring `Send` -{ - fn execute(&self, fut: F) { - tokio::task::spawn_local(fut); + .eval::() + .expect("Failed to create Lua handler"); + + let listen_addr = "127.0.0.1:3000"; + let listener = TcpListener::bind(listen_addr).await.unwrap(); + println!("Listening on http://{listen_addr}"); + + loop { + let (stream, peer_addr) = match listener.accept().await { + Ok(x) => x, + Err(err) => { + eprintln!("Failed to accept connection: {err}"); + continue; + } + }; + + let svc = Svc::new(handler.clone(), peer_addr); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection(TokioIo::new(stream), svc) + .await + { + eprintln!("Error serving connection: {:?}", err); + } + }); } } diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs index edfc1146..c78d9d8e 100644 --- a/examples/async_tcp_server.rs +++ b/examples/async_tcp_server.rs @@ -1,59 +1,42 @@ use std::io; use std::net::SocketAddr; -use std::rc::Rc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::task; -use mlua::{ - chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods, -}; +use mlua::{BString, Function, Lua, UserData, UserDataMethods, chunk}; struct LuaTcpStream(TcpStream); impl UserData for LuaTcpStream { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("peer_addr", |_, this, ()| { - Ok(this.0.peer_addr()?.to_string()) + fn add_methods>(methods: &mut M) { + methods.add_method("peer_addr", |_, this, ()| Ok(this.0.peer_addr()?.to_string())); + + methods.add_async_method_mut("read", |lua, mut this, size| async move { + let mut buf = vec![0; size]; + let n = this.0.read(&mut buf).await?; + buf.truncate(n); + lua.create_string(&buf) }); - methods.add_async_function( - "read", - |lua, (this, size): (AnyUserData, usize)| async move { - let mut this = this.borrow_mut::()?; - let mut buf = vec![0; size]; - let n = this.0.read(&mut buf).await?; - buf.truncate(n); - lua.create_string(&buf) - }, - ); - - methods.add_async_function( - "write", - |_, (this, data): (AnyUserData, LuaString)| async move { - let mut this = this.borrow_mut::()?; - let n = this.0.write(&data.as_bytes()).await?; - Ok(n) - }, - ); - - methods.add_async_function("close", |_, this: AnyUserData| async move { - let mut this = this.borrow_mut::()?; + methods.add_async_method_mut("write", |_, mut this, data: BString| async move { + let n = this.0.write(&data).await?; + Ok(n) + }); + + methods.add_async_method_mut("close", |_, mut this, ()| async move { this.0.shutdown().await?; Ok(()) }); } } -async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> { +async fn run_server(handler: Function) -> io::Result<()> { let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); let listener = TcpListener::bind(addr).await.expect("cannot bind addr"); println!("Listening on {}", addr); - let lua = Rc::new(lua); - let handler = Rc::new(handler); loop { let (stream, _) = match listener.accept().await { Ok(res) => res, @@ -61,15 +44,10 @@ async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> { Err(err) => return Err(err), }; - let lua = lua.clone(); let handler = handler.clone(); - task::spawn_local(async move { - let handler: Function = lua - .registry_value(&handler) - .expect("cannot get Lua handler"); - + tokio::task::spawn(async move { let stream = LuaTcpStream(stream); - if let Err(err) = handler.call_async::<_, ()>(stream).await { + if let Err(err) = handler.call_async::<()>(stream).await { eprintln!("{}", err); } }); @@ -81,7 +59,7 @@ async fn main() { let lua = Lua::new(); // Create Lua handler function - let handler_fn = lua + let handler = lua .load(chunk! { function(stream) local peer_addr = stream:peer_addr() @@ -103,15 +81,7 @@ async fn main() { .eval::() .expect("cannot create Lua handler"); - // Store it in the Registry - let handler = lua - .create_registry_value(handler_fn) - .expect("cannot store Lua handler"); - - task::LocalSet::new() - .run_until(run_server(lua, handler)) - .await - .expect("cannot run server") + run_server(handler).await.expect("cannot run server") } fn is_transient_error(e: &io::Error) -> bool { diff --git a/examples/guided_tour.rs b/examples/guided_tour.rs index ddc6d1f3..ba8b3ac4 100644 --- a/examples/guided_tour.rs +++ b/examples/guided_tour.rs @@ -1,7 +1,7 @@ use std::f32; use std::iter::FromIterator; -use mlua::{chunk, Function, Lua, MetaMethod, Result, UserData, UserDataMethods, Variadic}; +use mlua::{FromLua, Function, Lua, MetaMethod, Result, UserData, UserDataMethods, Value, Variadic, chunk}; fn main() -> Result<()> { // You can create a new Lua state with `Lua::new()`. This loads the default Lua std library @@ -17,12 +17,12 @@ fn main() -> Result<()> { globals.set("string_var", "hello")?; globals.set("int_var", 42)?; - assert_eq!(globals.get::<_, String>("string_var")?, "hello"); - assert_eq!(globals.get::<_, i64>("int_var")?, 42); + assert_eq!(globals.get::("string_var")?, "hello"); + assert_eq!(globals.get::("int_var")?, 42); // You can load and evaluate Lua code. The returned type of `Lua::load` is a builder // that allows you to change settings before running Lua code. Here, we are using it to set - // the name of the laoded chunk to "example code", which will be used when Lua error + // the name of the loaded chunk to "example code", which will be used when Lua error // messages are printed. lua.load( @@ -30,9 +30,9 @@ fn main() -> Result<()> { global = 'foo'..'bar' "#, ) - .set_name("example code")? + .set_name("example code") .exec()?; - assert_eq!(globals.get::<_, String>("global")?, "foobar"); + assert_eq!(globals.get::("global")?, "foobar"); assert_eq!(lua.load("1 + 1").eval::()?, 2); assert_eq!(lua.load("false == false").eval::()?, true); @@ -85,20 +85,20 @@ fn main() -> Result<()> { // You can load Lua functions let print: Function = globals.get("print")?; - print.call::<_, ()>("hello from rust")?; + print.call::<()>("hello from rust")?; - // This API generally handles variadics using tuples. This is one way to call a function with + // This API generally handles variadic using tuples. This is one way to call a function with // multiple parameters: - print.call::<_, ()>(("hello", "again", "from", "rust"))?; + print.call::<()>(("hello", "again", "from", "rust"))?; // But, you can also pass variadic arguments with the `Variadic` type. - print.call::<_, ()>(Variadic::from_iter( + print.call::<()>(Variadic::from_iter( ["hello", "yet", "again", "from", "rust"].iter().cloned(), ))?; - // You can bind rust functions to Lua as well. Callbacks receive the Lua state inself as their + // You can bind rust functions to Lua as well. Callbacks receive the Lua state itself as their // first parameter, and the arguments given to the function as the second parameter. The type // of the arguments can be anything that is convertible from the parameters given by Lua, in // this case, the function expects two string sequences. @@ -151,8 +151,18 @@ fn main() -> Result<()> { #[derive(Copy, Clone)] struct Vec2(f32, f32); + // We can implement `FromLua` trait for our `Vec2` to return a copy + impl FromLua for Vec2 { + fn from_lua(value: Value, _: &Lua) -> Result { + match value { + Value::UserData(ud) => Ok(*ud.borrow::()?), + _ => unreachable!(), + } + } + } + impl UserData for Vec2 { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("magnitude", |_, vec, ()| { let mag_squared = vec.0 * vec.0 + vec.1 * vec.1; Ok(mag_squared.sqrt()) @@ -167,19 +177,15 @@ fn main() -> Result<()> { let vec2_constructor = lua.create_function(|_, (x, y): (f32, f32)| Ok(Vec2(x, y)))?; globals.set("vec2", vec2_constructor)?; - assert!( - (lua.load("(vec2(1, 2) + vec2(2, 2)):magnitude()") - .eval::()? - - 5.0) - .abs() - < f32::EPSILON - ); + assert!((lua.load("(vec2(1, 2) + vec2(2, 2)):magnitude()").eval::()? - 5.0).abs() < f32::EPSILON); // Normally, Rust types passed to `Lua` must be `'static`, because there is no way to be // sure of their lifetime inside the Lua state. There is, however, a limited way to lift this // requirement. You can call `Lua::scope` to create userdata and callbacks types that only live // for as long as the call to scope, but do not have to be `'static` (and `Send`). + // TODO: Re-enable this + /* { let mut rust_val = 0; @@ -201,6 +207,7 @@ fn main() -> Result<()> { assert_eq!(rust_val, 42); } + */ // We were able to run our 'sketchy' function inside the scope just fine. However, if we // try to run our 'sketchy' function outside of the scope, the function we created will have diff --git a/examples/module/Cargo.toml b/examples/module/Cargo.toml index 504c01e2..588a4fea 100644 --- a/examples/module/Cargo.toml +++ b/examples/module/Cargo.toml @@ -2,7 +2,7 @@ name = "rust_module" version = "0.0.0" authors = ["Aleksandr Orlenko "] -edition = "2018" +edition = "2021" [lib] crate-type = ["cdylib"] @@ -10,6 +10,7 @@ crate-type = ["cdylib"] [workspace] [features] +lua55 = ["mlua/lua55"] lua54 = ["mlua/lua54"] lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] diff --git a/examples/repl.rs b/examples/repl.rs index a5fe47a6..98355cea 100644 --- a/examples/repl.rs +++ b/examples/repl.rs @@ -1,11 +1,11 @@ //! This example shows a simple read-evaluate-print-loop (REPL). use mlua::{Error, Lua, MultiValue}; -use rustyline::Editor; +use rustyline::DefaultEditor; fn main() { let lua = Lua::new(); - let mut editor = Editor::<()>::new(); + let mut editor = DefaultEditor::new().expect("Failed to create editor"); loop { let mut prompt = "> "; @@ -19,15 +19,17 @@ fn main() { match lua.load(&line).eval::() { Ok(values) => { - editor.add_history_entry(line); - println!( - "{}", - values - .iter() - .map(|value| format!("{:?}", value)) - .collect::>() - .join("\t") - ); + editor.add_history_entry(line).unwrap(); + if values.len() > 0 { + println!( + "{}", + values + .iter() + .map(|value| format!("{:#?}", value)) + .collect::>() + .join("\t") + ); + } break; } Err(Error::SyntaxError { diff --git a/examples/serialize.rs b/examples/serde.rs similarity index 93% rename from examples/serialize.rs rename to examples/serde.rs index 7c9ae69e..fff968c4 100644 --- a/examples/serialize.rs +++ b/examples/serde.rs @@ -28,9 +28,14 @@ fn main() -> Result<()> { let globals = lua.globals(); // Create Car struct from a Lua table - let car: Car = lua.from_value(lua.load(r#" + let car: Car = lua.from_value( + lua.load( + r#" {active = true, model = "Volkswagen Golf", transmission = "Automatic", engine = {v = 1499, kw = 90}} - "#).eval()?)?; + "#, + ) + .eval()?, + )?; // Set it as (serializable) userdata globals.set("null", lua.null())?; diff --git a/examples/userdata.rs b/examples/userdata.rs index 8823dfd5..6a21e90b 100644 --- a/examples/userdata.rs +++ b/examples/userdata.rs @@ -1,4 +1,4 @@ -use mlua::{chunk, Lua, MetaMethod, Result, UserData}; +use mlua::{Lua, MetaMethod, Result, UserData, chunk}; #[derive(Default)] struct Rectangle { @@ -7,7 +7,7 @@ struct Rectangle { } impl UserData for Rectangle { - fn add_fields<'lua, F: mlua::UserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { fields.add_field_method_get("length", |_, this| Ok(this.length)); fields.add_field_method_set("length", |_, this, val| { this.length = val; @@ -20,7 +20,7 @@ impl UserData for Rectangle { }); } - fn add_methods<'lua, M: mlua::UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("area", |_, this, ()| Ok(this.length * this.width)); methods.add_method("diagonal", |_, this, ()| { Ok((this.length.pow(2) as f64 + this.width.pow(2) as f64).sqrt()) diff --git a/mlua-sys/Cargo.toml b/mlua-sys/Cargo.toml new file mode 100644 index 00000000..e1bb93ce --- /dev/null +++ b/mlua-sys/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "mlua-sys" +version = "0.11.0-rc.1" +authors = ["Aleksandr Orlenko "] +rust-version = "1.88" +edition = "2024" +repository = "https://github.com/mlua-rs/mlua" +documentation = "https://docs.rs/mlua-sys" +readme = "README.md" +categories = ["external-ffi-bindings"] +license = "MIT" +links = "lua" +build = "build/main.rs" +description = """ +Low level (FFI) bindings to Lua 5.5/5.4/5.3/5.2/5.1 (including LuaJIT) and Luau +""" + +[package.metadata.docs.rs] +features = ["lua55", "vendored"] +rustdoc-args = ["--cfg", "docsrs"] + +[features] +lua55 = [] +lua54 = [] +lua53 = [] +lua52 = [] +lua51 = [] +luajit = [] +luajit52 = ["luajit"] +luau = ["luau0-src"] +luau-codegen = ["luau"] +luau-vector4 = ["luau"] +vendored = ["lua-src", "luajit-src"] +external = [] +module = [] + +[dependencies] +libc = "0.2" + +[build-dependencies] +cc = "1.0" +cfg-if = "1.0" +pkg-config = "0.3.17" +lua-src = { version = ">= 550.1.0, < 550.2.0", optional = true } +luajit-src = { version = ">= 210.7.0, < 210.8.0", optional = true } +luau0-src = { version = "0.20.0", optional = true } + +[lints.rust] +unexpected_cfgs = { level = "allow", check-cfg = ['cfg(raw_dylib)'] } diff --git a/mlua-sys/README.md b/mlua-sys/README.md new file mode 100644 index 00000000..def6e91e --- /dev/null +++ b/mlua-sys/README.md @@ -0,0 +1,9 @@ +# mlua-sys + +Low level (FFI) bindings to Lua 5.5/5.4/5.3/5.2/5.1 (including [LuaJIT]) and [Luau]. + +Intended to be consumed by the [mlua] crate. + +[LuaJIT]: https://github.com/LuaJIT/LuaJIT +[Luau]: https://github.com/luau-lang/luau +[mlua]: https://crates.io/crates/mlua diff --git a/mlua-sys/build/find_normal.rs b/mlua-sys/build/find_normal.rs new file mode 100644 index 00000000..b91b1c01 --- /dev/null +++ b/mlua-sys/build/find_normal.rs @@ -0,0 +1,64 @@ +#![allow(dead_code)] + +use std::env; +use std::ops::Bound; + +pub fn probe_lua() { + let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap(); + + if target_arch == "wasm32" && cfg!(not(feature = "vendored")) { + panic!("Please enable `vendored` feature to build for wasm32"); + } + + let lib_dir = env::var("LUA_LIB").unwrap_or_default(); + let lua_lib = env::var("LUA_LIB_NAME").unwrap_or_default(); + + println!("cargo:rerun-if-env-changed=LUA_LIB"); + println!("cargo:rerun-if-env-changed=LUA_LIB_NAME"); + println!("cargo:rerun-if-env-changed=LUA_LINK"); + + if !lua_lib.is_empty() { + if !lib_dir.is_empty() { + println!("cargo:rustc-link-search=native={lib_dir}"); + } + let mut link_lib = ""; + if env::var("LUA_LINK").as_deref() == Ok("static") { + link_lib = "static="; + }; + println!("cargo:rustc-link-lib={link_lib}{lua_lib}"); + return; + } + + // Find using `pkg-config` + + #[cfg(feature = "lua55")] + let (incl_bound, excl_bound, alt_probe, ver) = ("5.5", "5.6", ["lua5.5", "lua-5.5", "lua55"], "5.5"); + #[cfg(feature = "lua54")] + let (incl_bound, excl_bound, alt_probe, ver) = ("5.4", "5.5", ["lua5.4", "lua-5.4", "lua54"], "5.4"); + #[cfg(feature = "lua53")] + let (incl_bound, excl_bound, alt_probe, ver) = ("5.3", "5.4", ["lua5.3", "lua-5.3", "lua53"], "5.3"); + #[cfg(feature = "lua52")] + let (incl_bound, excl_bound, alt_probe, ver) = ("5.2", "5.3", ["lua5.2", "lua-5.2", "lua52"], "5.2"); + #[cfg(feature = "lua51")] + let (incl_bound, excl_bound, alt_probe, ver) = ("5.1", "5.2", ["lua5.1", "lua-5.1", "lua51"], "5.1"); + #[cfg(feature = "luajit")] + let (incl_bound, excl_bound, alt_probe, ver) = ("2.0.4", "2.2", [], "JIT"); + + #[rustfmt::skip] + let mut lua = pkg_config::Config::new() + .range_version((Bound::Included(incl_bound), Bound::Excluded(excl_bound))) + .cargo_metadata(true) + .probe(if cfg!(feature = "luajit") { "luajit" } else { "lua" }); + + if lua.is_err() { + for pkg in alt_probe { + lua = pkg_config::Config::new().cargo_metadata(true).probe(pkg); + + if lua.is_ok() { + break; + } + } + } + + lua.unwrap_or_else(|err| panic!("cannot find Lua{ver} using `pkg-config`: {err}")); +} diff --git a/build/find_vendored.rs b/mlua-sys/build/find_vendored.rs similarity index 52% rename from build/find_vendored.rs rename to mlua-sys/build/find_vendored.rs index 5b7b8e1a..78105250 100644 --- a/build/find_vendored.rs +++ b/mlua-sys/build/find_vendored.rs @@ -1,28 +1,32 @@ #![allow(dead_code)] -use std::path::PathBuf; +pub fn probe_lua() { + #[cfg(feature = "lua55")] + let artifacts = lua_src::Build::new().build(lua_src::Lua55); -pub fn probe_lua() -> Option { #[cfg(feature = "lua54")] let artifacts = lua_src::Build::new().build(lua_src::Lua54); + #[cfg(feature = "lua53")] let artifacts = lua_src::Build::new().build(lua_src::Lua53); + #[cfg(feature = "lua52")] let artifacts = lua_src::Build::new().build(lua_src::Lua52); + #[cfg(feature = "lua51")] let artifacts = lua_src::Build::new().build(lua_src::Lua51); + #[cfg(feature = "luajit")] - let artifacts = { - let mut builder = luajit_src::Build::new(); - if cfg!(feature = "luajit52") { - builder.lua52compat(true); - } - builder.build() - }; + let artifacts = luajit_src::Build::new() + .lua52compat(cfg!(feature = "luajit52")) + .build(); + #[cfg(feature = "luau")] - let artifacts = luau0_src::Build::new().build(); + let artifacts = luau0_src::Build::new() + .enable_codegen(cfg!(feature = "luau-codegen")) + .set_max_cstack_size(1000000) + .set_vector_size(if cfg!(feature = "luau-vector4") { 4 } else { 3 }) + .build(); artifacts.print_cargo_metadata(); - - Some(artifacts.include_dir().to_owned()) } diff --git a/mlua-sys/build/main.rs b/mlua-sys/build/main.rs new file mode 100644 index 00000000..5bb8ddbe --- /dev/null +++ b/mlua-sys/build/main.rs @@ -0,0 +1,21 @@ +cfg_if::cfg_if! { + if #[cfg(all(feature = "lua55", not(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51", feature = "luajit", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "lua54", not(any(feature = "lua55", feature = "lua53", feature = "lua52", feature = "lua51", feature = "luajit", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "lua53", not(any(feature = "lua55", feature = "lua54", feature = "lua52", feature = "lua51", feature = "luajit", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "lua52", not(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua51", feature = "luajit", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "lua51", not(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "luajit", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "luajit", not(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51", feature = "luau"))))] { + include!("main_inner.rs"); + } else if #[cfg(all(feature = "luau", not(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "lua51", feature = "luajit"))))] { + include!("main_inner.rs"); + } else { + fn main() { + compile_error!("You can enable only one of the features: lua55, lua54, lua53, lua52, lua51, luajit, luajit52, luau"); + } + } +} diff --git a/mlua-sys/build/main_inner.rs b/mlua-sys/build/main_inner.rs new file mode 100644 index 00000000..5d6c04ed --- /dev/null +++ b/mlua-sys/build/main_inner.rs @@ -0,0 +1,42 @@ +use std::env; + +cfg_if::cfg_if! { + if #[cfg(any(feature = "luau", feature = "vendored"))] { + #[path = "find_vendored.rs"] + mod find; + } else { + #[path = "find_normal.rs"] + mod find; + } +} + +fn main() { + #[cfg(all(feature = "luau", feature = "module", windows))] + compile_error!("Luau does not support `module` mode on Windows"); + + #[cfg(any( + all(feature = "vendored", any(feature = "external", feature = "module")), + all(feature = "external", any(feature = "vendored", feature = "module")), + all(feature = "module", any(feature = "vendored", feature = "external")) + ))] + compile_error!("`vendored`, `external` and `module` features are mutually exclusive"); + + println!("cargo:rerun-if-changed=build"); + + // Check if compilation and linking is handled by external crate + if cfg!(not(feature = "external")) { + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap(); + if target_os == "windows" && cfg!(feature = "module") { + if !std::env::var("LUA_LIB_NAME").unwrap_or_default().is_empty() { + // Don't use raw-dylib linking + find::probe_lua(); + return; + } + + println!("cargo:rustc-cfg=raw_dylib"); + } + + #[cfg(not(feature = "module"))] + find::probe_lua(); + } +} diff --git a/mlua-sys/src/lib.rs b/mlua-sys/src/lib.rs new file mode 100644 index 00000000..e6672c88 --- /dev/null +++ b/mlua-sys/src/lib.rs @@ -0,0 +1,115 @@ +//! Low level bindings to Lua 5.5/5.4/5.3/5.2/5.1 (including LuaJIT) and Luau. + +#![allow(non_camel_case_types, non_snake_case)] +#![allow(clippy::missing_safety_doc)] +#![allow(unsafe_op_in_unsafe_fn)] +#![doc(test(attr(deny(warnings))))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +use std::os::raw::c_int; + +#[cfg(any(feature = "lua55", doc))] +pub use lua55::*; + +#[cfg(any(feature = "lua54", doc))] +pub use lua54::*; + +#[cfg(any(feature = "lua53", doc))] +pub use lua53::*; + +#[cfg(any(feature = "lua52", doc))] +pub use lua52::*; + +#[cfg(any(feature = "lua51", feature = "luajit", doc))] +pub use lua51::*; + +#[cfg(any(feature = "luau", doc))] +pub use luau::*; + +#[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] +#[doc(hidden)] +pub const LUA_MAX_UPVALUES: c_int = 255; + +#[cfg(any(feature = "lua51", feature = "luajit"))] +#[doc(hidden)] +pub const LUA_MAX_UPVALUES: c_int = 60; + +#[cfg(feature = "luau")] +#[doc(hidden)] +pub const LUA_MAX_UPVALUES: c_int = 200; + +// I believe `luaL_traceback` < 5.4 requires this much free stack to not error. +// 5.4 uses `luaL_Buffer` +#[doc(hidden)] +pub const LUA_TRACEBACK_STACK: c_int = 11; + +// The minimum alignment guaranteed by the architecture. +// Copied from https://github.com/rust-lang/rust/blob/main/library/std/src/sys/alloc/mod.rs +#[doc(hidden)] +#[rustfmt::skip] +pub const SYS_MIN_ALIGN: usize = if cfg!(any( + all(target_arch = "riscv32", any(target_os = "espidf", target_os = "zkvm")), + all(target_arch = "xtensa", target_os = "espidf"), +)) { + // The allocator on the esp-idf and zkvm platforms guarantees 4 byte alignment. + 4 +} else if cfg!(any( + target_arch = "x86", + target_arch = "arm", + target_arch = "m68k", + target_arch = "csky", + target_arch = "loongarch32", + target_arch = "mips", + target_arch = "mips32r6", + target_arch = "powerpc", + target_arch = "powerpc64", + target_arch = "sparc", + target_arch = "wasm32", + target_arch = "hexagon", + target_arch = "riscv32", + target_arch = "xtensa", +)) { + 8 +} else if cfg!(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "arm64ec", + target_arch = "loongarch64", + target_arch = "mips64", + target_arch = "mips64r6", + target_arch = "s390x", + target_arch = "sparc64", + target_arch = "riscv64", + target_arch = "wasm64", +)) { + 16 +} else { + panic!("no value for SYS_MIN_ALIGN") +}; + +#[macro_use] +mod macros; + +#[cfg(any(feature = "lua55", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "lua55")))] +pub mod lua55; + +#[cfg(any(feature = "lua54", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] +pub mod lua54; + +#[cfg(any(feature = "lua53", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "lua53")))] +pub mod lua53; + +#[cfg(any(feature = "lua52", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "lua52")))] +pub mod lua52; + +#[cfg(any(feature = "lua51", feature = "luajit", doc))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "lua51", feature = "luajit"))))] +pub mod lua51; + +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +pub mod luau; diff --git a/src/ffi/lua51/compat.rs b/mlua-sys/src/lua51/compat.rs similarity index 72% rename from src/ffi/lua51/compat.rs rename to mlua-sys/src/lua51/compat.rs index 9cd46fb6..5dc02546 100644 --- a/src/ffi/lua51/compat.rs +++ b/mlua-sys/src/lua51/compat.rs @@ -2,10 +2,9 @@ //! //! Based on github.com/keplerproject/lua-compat-5.3 -use std::convert::TryInto; -use std::mem; +use std::ffi::CStr; use std::os::raw::{c_char, c_int, c_void}; -use std::ptr; +use std::{mem, ptr}; use super::lauxlib::*; use super::lua::*; @@ -22,8 +21,8 @@ unsafe fn compat53_reverse(L: *mut lua_State, mut a: c_int, mut b: c_int) { } } -const COMPAT53_LEVELS1: c_int = 12; // size of the first part of the stack -const COMPAT53_LEVELS2: c_int = 10; // size of the second part of the stack +const COMPAT53_LEVELS1: c_int = 10; // size of the first part of the stack +const COMPAT53_LEVELS2: c_int = 11; // size of the second part of the stack unsafe fn compat53_countlevels(L: *mut lua_State) -> c_int { let mut ar: lua_Debug = mem::zeroed(); @@ -56,11 +55,7 @@ unsafe fn compat53_checkmode( while *st != 0 && *st != c { st = st.offset(1); } - if *st == c { - st - } else { - ptr::null() - } + if *st == c { st } else { ptr::null() } } if !mode.is_null() && strchr(mode, *modename).is_null() { @@ -90,11 +85,10 @@ unsafe fn compat53_findfield(L: *mut lua_State, objidx: c_int, level: c_int) -> lua_pop(L, 1); // remove value (but keep name) return 1; } else if compat53_findfield(L, objidx, level - 1) != 0 { - // try recursively - lua_remove(L, -2); // remove table (but keep name) - lua_pushliteral(L, "."); - lua_insert(L, -2); // place '.' between the two names - lua_concat(L, 3); + // stack: lib_name, lib_table, field_name (top) + lua_pushliteral(L, c"."); // place '.' between the two names + lua_replace(L, -3); // (in the slot occupied by table) + lua_concat(L, 3); // lib_name.field_name return 1; } } @@ -103,13 +97,20 @@ unsafe fn compat53_findfield(L: *mut lua_State, objidx: c_int, level: c_int) -> 0 // not found } -unsafe fn compat53_pushglobalfuncname(L: *mut lua_State, ar: *mut lua_Debug) -> c_int { +unsafe fn compat53_pushglobalfuncname(L: *mut lua_State, L1: *mut lua_State, ar: *mut lua_Debug) -> c_int { let top = lua_gettop(L); - lua_getinfo(L, cstr!("f"), ar); // push function + lua_getinfo(L1, cstr!("f"), ar); // push function + lua_xmove(L1, L, 1); // and move onto L lua_pushvalue(L, LUA_GLOBALSINDEX); + luaL_checkstack(L, 6, cstr!("not enough stack")); // slots for 'findfield' if compat53_findfield(L, top + 1, 2) != 0 { + let name = lua_tostring(L, -1); + if CStr::from_ptr(name).to_bytes().starts_with(b"_G.") { + lua_pushstring(L, name.add(3)); // push name without prefix + lua_remove(L, -2); // remove original name + } lua_copy(L, -1, top + 1); // move name to proper place - lua_pop(L, 2); // remove pushed values + lua_settop(L, top + 1); // remove pushed values 1 } else { lua_settop(L, top); // remove function and global table @@ -117,27 +118,23 @@ unsafe fn compat53_pushglobalfuncname(L: *mut lua_State, ar: *mut lua_Debug) -> } } -unsafe fn compat53_pushfuncname(L: *mut lua_State, ar: *mut lua_Debug) { - if *(*ar).namewhat != b'\0' as c_char { - // is there a name? - lua_pushfstring(L, cstr!("function '%s'"), (*ar).name); +unsafe fn compat53_pushfuncname(L: *mut lua_State, L1: *mut lua_State, ar: *mut lua_Debug) { + // try first a global name + if compat53_pushglobalfuncname(L, L1, ar) != 0 { + lua_pushfstring(L, cstr!("function '%s'"), lua_tostring(L, -1)); + lua_remove(L, -2); // remove name + } else if *(*ar).namewhat != b'\0' as c_char { + // use name from code + lua_pushfstring(L, cstr!("%s '%s'"), (*ar).namewhat, (*ar).name); } else if *(*ar).what == b'm' as c_char { // main? - lua_pushliteral(L, "main chunk"); - } else if *(*ar).what == b'C' as c_char { - if compat53_pushglobalfuncname(L, ar) != 0 { - lua_pushfstring(L, cstr!("function '%s'"), lua_tostring(L, -1)); - lua_remove(L, -2); // remove name - } else { - lua_pushliteral(L, "?"); - } + lua_pushliteral(L, c"main chunk"); + } else if *(*ar).what != b'C' as c_char { + // for Lua functions, use + let short_src = (*ar).short_src.as_ptr(); + lua_pushfstring(L, cstr!("function <%s:%d>"), short_src, (*ar).linedefined); } else { - lua_pushfstring( - L, - cstr!("function <%s:%d>"), - (*ar).short_src.as_ptr(), - (*ar).linedefined, - ); + lua_pushliteral(L, c"?"); } } @@ -178,7 +175,7 @@ pub unsafe fn lua_rotate(L: *mut lua_State, mut idx: c_int, mut n: c_int) { #[inline(always)] pub unsafe fn lua_copy(L: *mut lua_State, fromidx: c_int, toidx: c_int) { let abs_to = lua_absindex(L, toidx); - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); lua_pushvalue(L, fromidx); lua_replace(L, abs_to); } @@ -188,7 +185,8 @@ pub unsafe fn lua_isinteger(L: *mut lua_State, idx: c_int) -> c_int { if lua_type(L, idx) == LUA_TNUMBER { let n = lua_tonumber(L, idx); let i = lua_tointeger(L, idx); - if (n - i as lua_Number).abs() < lua_Number::EPSILON { + // Lua 5.3+ returns "false" for `-0.0` + if n.to_bits() == (i as lua_Number).to_bits() { return 1; } } @@ -316,7 +314,7 @@ pub unsafe fn lua_rawseti(L: *mut lua_State, idx: c_int, n: lua_Integer) { #[inline(always)] pub unsafe fn lua_rawsetp(L: *mut lua_State, idx: c_int, p: *const c_void) { let abs_i = lua_absindex(L, idx); - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); lua_pushlightuserdata(L, p as *mut c_void); lua_insert(L, -2); lua_rawset(L, abs_i); @@ -329,12 +327,7 @@ pub unsafe fn lua_setuservalue(L: *mut lua_State, idx: c_int) { } #[inline(always)] -pub unsafe fn lua_dump( - L: *mut lua_State, - writer: lua_Writer, - data: *mut c_void, - _strip: c_int, -) -> c_int { +pub unsafe fn lua_dump(L: *mut lua_State, writer: lua_Writer, data: *mut c_void, _strip: c_int) -> c_int { lua_dump_(L, writer, data) } @@ -366,12 +359,7 @@ pub unsafe fn lua_pushglobaltable(L: *mut lua_State) { } #[inline(always)] -pub unsafe fn lua_resume( - L: *mut lua_State, - _from: *mut lua_State, - narg: c_int, - nres: *mut c_int, -) -> c_int { +pub unsafe fn lua_resume(L: *mut lua_State, _from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int { let ret = lua_resume_(L, narg); if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { *nres = lua_gettop(L); @@ -389,7 +377,7 @@ pub unsafe fn luaL_checkstack(L: *mut lua_State, sz: c_int, msg: *const c_char) if !msg.is_null() { luaL_error(L, cstr!("stack overflow (%s)"), msg); } else { - lua_pushliteral(L, "stack overflow"); + lua_pushliteral(L, c"stack overflow"); lua_error(L); } } @@ -415,6 +403,25 @@ pub unsafe fn luaL_newmetatable(L: *mut lua_State, tname: *const c_char) -> c_in } } +pub unsafe fn luaL_loadbufferenv( + L: *mut lua_State, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, + mut env: c_int, +) -> c_int { + if env != 0 { + env = lua_absindex(L, env); + } + let status = luaL_loadbufferx(L, data, size, name, mode); + if status == LUA_OK && env != 0 { + lua_pushvalue(L, env); + lua_setfenv(L, -2); + } + status +} + #[inline(always)] pub unsafe fn luaL_loadbufferx( L: *mut lua_State, @@ -437,7 +444,7 @@ pub unsafe fn luaL_loadbufferx( #[inline(always)] pub unsafe fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer { let mut isnum = 0; - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); lua_len(L, idx); let res = lua_tointegerx(L, -1, &mut isnum); lua_pop(L, 1); @@ -447,63 +454,62 @@ pub unsafe fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer { res } -pub unsafe fn luaL_traceback( - L: *mut lua_State, - L1: *mut lua_State, - msg: *const c_char, - mut level: c_int, -) { +pub unsafe fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, mut level: c_int) { let mut ar: lua_Debug = mem::zeroed(); let top = lua_gettop(L); let numlevels = compat53_countlevels(L1); - let mark = if numlevels > COMPAT53_LEVELS1 + COMPAT53_LEVELS2 { - COMPAT53_LEVELS1 - } else { - 0 - }; + #[rustfmt::skip] + let mut limit = if numlevels - level > COMPAT53_LEVELS1 + COMPAT53_LEVELS2 { COMPAT53_LEVELS1 } else { -1 }; if !msg.is_null() { lua_pushfstring(L, cstr!("%s\n"), msg); } - lua_pushliteral(L, "stack traceback:"); + lua_pushliteral(L, c"stack traceback:"); while lua_getstack(L1, level, &mut ar) != 0 { - level += 1; - if level == mark { + if limit == 0 { // too many levels? - lua_pushliteral(L, "\n\t..."); // add a '...' - level = numlevels - COMPAT53_LEVELS2; // and skip to last ones + let n = numlevels - level - COMPAT53_LEVELS2; + // add warning about skip ("n + 1" because we skip current level too) + lua_pushfstring(L, cstr!("\n\t...\t(skipping %d levels)"), n + 1); // add warning about skip + level += n; // and skip to last levels } else { - lua_getinfo(L1, cstr!("Slnt"), &mut ar); - lua_pushfstring(L, cstr!("\n\t%s:"), ar.short_src.as_ptr()); - if ar.currentline > 0 { - lua_pushfstring(L, cstr!("%d:"), ar.currentline); + lua_getinfo(L1, cstr!("Sln"), &mut ar); + if *ar.what != b't' as c_char { + if ar.currentline <= 0 { + lua_pushfstring(L, cstr!("\n\t%s: in "), ar.short_src.as_ptr()); + } else { + lua_pushfstring(L, cstr!("\n\t%s:%d: in "), ar.short_src.as_ptr(), ar.currentline); + } + compat53_pushfuncname(L, L1, &mut ar); + lua_concat(L, lua_gettop(L) - top); + } else { + lua_pushstring(L, cstr!("\n\t(...tail calls...)")); } - lua_pushliteral(L, " in "); - compat53_pushfuncname(L, &mut ar); - lua_concat(L, lua_gettop(L) - top); } + level += 1; + limit -= 1; } lua_concat(L, lua_gettop(L) - top); } -pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char { +pub unsafe fn luaL_tolstring(L: *mut lua_State, mut idx: c_int, len: *mut usize) -> *const c_char { + idx = lua_absindex(L, idx); if luaL_callmeta(L, idx, cstr!("__tostring")) == 0 { - let t = lua_type(L, idx); - match t { + match lua_type(L, idx) { LUA_TNIL => { - lua_pushliteral(L, "nil"); + lua_pushliteral(L, c"nil"); } LUA_TSTRING | LUA_TNUMBER => { lua_pushvalue(L, idx); } LUA_TBOOLEAN => { if lua_toboolean(L, idx) == 0 { - lua_pushliteral(L, "false"); + lua_pushliteral(L, c"false"); } else { - lua_pushliteral(L, "true"); + lua_pushliteral(L, c"true"); } } - _ => { + t => { let tt = luaL_getmetafield(L, idx, cstr!("__name")); let name = if tt == LUA_TSTRING { lua_tostring(L, -1) @@ -512,7 +518,7 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> }; lua_pushfstring(L, cstr!("%s: %p"), name, lua_topointer(L, idx)); if tt != LUA_TNIL { - lua_replace(L, -2); + lua_replace(L, -2); // remove '__name' } } }; @@ -524,14 +530,14 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> #[inline(always)] pub unsafe fn luaL_setmetatable(L: *mut lua_State, tname: *const c_char) { - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); luaL_getmetatable(L, tname); lua_setmetatable(L, -2); } pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_char) -> c_int { let abs_i = lua_absindex(L, idx); - luaL_checkstack(L, 3, cstr!("not enough stack slots")); + luaL_checkstack(L, 3, cstr!("not enough stack slots available")); lua_pushstring_(L, fname); if lua_gettable(L, abs_i) == LUA_TTABLE { return 1; @@ -544,14 +550,9 @@ pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_ch 0 } -pub unsafe fn luaL_requiref( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, -) { +pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) { luaL_checkstack(L, 3, cstr!("not enough stack slots available")); - luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED")); + luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE); if lua_getfield(L, -1, modname) == LUA_TNIL { lua_pop(L, 1); lua_pushcfunction(L, openf); @@ -568,11 +569,10 @@ pub unsafe fn luaL_requiref( lua_getfield(L, -1, modname); } } - if cfg!(feature = "lua51") && glb != 0 { + if glb != 0 { lua_pushvalue(L, -1); lua_setglobal(L, modname); - } - if cfg!(feature = "luajit") && glb == 0 { + } else { lua_pushnil(L); lua_setglobal(L, modname); } diff --git a/src/ffi/lua51/lauxlib.rs b/mlua-sys/src/lua51/lauxlib.rs similarity index 66% rename from src/ffi/lua51/lauxlib.rs rename to mlua-sys/src/lua51/lauxlib.rs index cdce05e1..767d8fe0 100644 --- a/src/ffi/lua51/lauxlib.rs +++ b/mlua-sys/src/lua51/lauxlib.rs @@ -8,13 +8,17 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; // Extra error code for 'luaL_load' pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); + #[repr(C)] pub struct luaL_Reg { pub name: *const c_char, pub func: lua_CFunction, } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_register(L: *mut lua_State, libname: *const c_char, l: *const luaL_Reg); #[link_name = "luaL_getmetafield"] pub fn luaL_getmetafield_(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; @@ -42,7 +46,7 @@ extern "C" { pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; pub fn luaL_where(L: *mut lua_State, lvl: c_int); - pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> !; + pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> c_int; pub fn luaL_checkoption( L: *mut lua_State, @@ -56,17 +60,13 @@ extern "C" { pub const LUA_NOREF: c_int = -2; pub const LUA_REFNIL: c_int = -1; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_ref(L: *mut lua_State, t: c_int) -> c_int; pub fn luaL_unref(L: *mut lua_State, t: c_int, r#ref: c_int); pub fn luaL_loadfile(L: *mut lua_State, filename: *const c_char) -> c_int; - pub fn luaL_loadbuffer( - L: *mut lua_State, - buff: *const c_char, - sz: usize, - name: *const c_char, - ) -> c_int; + pub fn luaL_loadbuffer(L: *mut lua_State, buff: *const c_char, sz: usize, name: *const c_char) -> c_int; pub fn luaL_loadstring(L: *mut lua_State, s: *const c_char) -> c_int; pub fn luaL_newstate() -> *mut lua_State; @@ -107,8 +107,6 @@ pub unsafe fn luaL_optstring(L: *mut lua_State, n: c_int, d: *const c_char) -> * luaL_optlstring(L, n, d, ptr::null_mut()) } -// Deprecated from 5.3: luaL_checkint, luaL_optint, luaL_checklong, luaL_optlong - #[inline(always)] pub unsafe fn luaL_typename(L: *mut lua_State, i: c_int) -> *const c_char { lua::lua_typename(L, lua::lua_type(L, i)) @@ -138,8 +136,62 @@ pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) { lua::lua_getfield_(L, lua::LUA_REGISTRYINDEX, n); } -// TODO: luaL_opt +#[inline(always)] +pub unsafe fn luaL_opt( + L: *mut lua_State, + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } +} // -// TODO: Generic Buffer Manipulation +// Generic Buffer Manipulation // + +#[cfg(target_arch = "wasm32")] +const BUFSIZ: usize = 1024; // WASI libc's BUFSIZ is 1024 +#[cfg(not(target_arch = "wasm32"))] +const BUFSIZ: usize = libc::BUFSIZ as usize; + +// The buffer size used by the lauxlib buffer system. +// The "16384" workaround is taken from the LuaJIT source code. +pub const LUAL_BUFFERSIZE: usize = if BUFSIZ > 16384 { 8192 } else { BUFSIZ }; + +#[repr(C)] +pub struct luaL_Buffer { + pub p: *mut c_char, // current position in buffer + pub lvl: c_int, // number of strings in the stack + pub L: *mut lua_State, + pub buffer: [c_char; LUAL_BUFFERSIZE], +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Buffer); + pub fn luaL_prepbuffer(B: *mut luaL_Buffer) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Buffer, s: *const c_char, l: usize); + pub fn luaL_addstring(B: *mut luaL_Buffer, s: *const c_char); + pub fn luaL_addvalue(B: *mut luaL_Buffer); + pub fn luaL_pushresult(B: *mut luaL_Buffer); +} + +#[inline(always)] +pub unsafe fn luaL_addchar(B: *mut luaL_Buffer, c: c_char) { + let buffer_end = (*B).buffer.as_mut_ptr().add(LUAL_BUFFERSIZE); + if (*B).p >= buffer_end { + luaL_prepbuffer(B); + } + *(*B).p = c; + (*B).p = (*B).p.add(1); +} + +#[inline(always)] +pub unsafe fn luaL_addsize(B: *mut luaL_Buffer, n: usize) { + (*B).p = (*B).p.add(n); +} diff --git a/src/ffi/lua51/lua.rs b/mlua-sys/src/lua51/lua.rs similarity index 83% rename from src/ffi/lua51/lua.rs rename to mlua-sys/src/lua51/lua.rs index 3ad2db56..fbce6b96 100644 --- a/src/ffi/lua51/lua.rs +++ b/mlua-sys/src/lua51/lua.rs @@ -1,5 +1,6 @@ //! Contains definitions from `lua.h`. +use std::ffi::CStr; use std::marker::{PhantomData, PhantomPinned}; use std::os::raw::{c_char, c_double, c_int, c_void}; use std::ptr; @@ -67,26 +68,29 @@ pub const LUA_MINSTACK: c_int = 20; pub type lua_Number = c_double; /// A Lua integer, usually equivalent to `i64` -pub type lua_Integer = isize; +#[cfg(target_pointer_width = "32")] +pub type lua_Integer = i32; +#[cfg(target_pointer_width = "64")] +pub type lua_Integer = i64; /// Type for native C functions that can be passed to Lua. -pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; // Type for functions that read/write blocks when loading/dumping Lua chunks +#[rustfmt::skip] pub type lua_Reader = - unsafe extern "C" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; + unsafe extern "C-unwind" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; +#[rustfmt::skip] pub type lua_Writer = - unsafe extern "C" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; -/// Type for memory-allocation functions -pub type lua_Alloc = unsafe extern "C" fn( - ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, -) -> *mut c_void; +/// Type for memory-allocation functions (no unwinding) +#[rustfmt::skip] +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // State manipulation // @@ -94,9 +98,6 @@ extern "C" { pub fn lua_close(L: *mut lua_State); pub fn lua_newthread(L: *mut lua_State) -> *mut lua_State; - #[cfg(all(feature = "luajit", feature = "vendored"))] - pub fn lua_resetthread(L: *mut lua_State, th: *mut lua_State); - pub fn lua_atpanic(L: *mut lua_State, panicf: lua_CFunction) -> lua_CFunction; // @@ -218,21 +219,33 @@ pub const LUA_GCSTEP: c_int = 5; pub const LUA_GCSETPAUSE: c_int = 6; pub const LUA_GCSETSTEPMUL: c_int = 7; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_gc(L: *mut lua_State, what: c_int, data: c_int) -> c_int; } // // Miscellaneous functions // -extern "C" { - pub fn lua_error(L: *mut lua_State) -> !; +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + #[link_name = "lua_error"] + fn lua_error_(L: *mut lua_State) -> c_int; pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_concat(L: *mut lua_State, n: c_int); pub fn lua_getallocf(L: *mut lua_State, ud: *mut *mut c_void) -> lua_Alloc; pub fn lua_setallocf(L: *mut lua_State, f: lua_Alloc, ud: *mut c_void); } +// lua_error does not return but is declared to return int, and Rust translates +// ! to void which can cause link-time errors if the platform linker is aware +// of return types and requires they match (for example: wasm does this). +#[inline(always)] +pub unsafe fn lua_error(L: *mut lua_State) -> ! { + lua_error_(L); + unreachable!(); +} + // // Some useful macros (implemented as Rust functions) // @@ -257,7 +270,10 @@ pub unsafe fn lua_pushcfunction(L: *mut lua_State, f: lua_CFunction) { lua_pushcclosure(L, f, 0) } -// TODO: lua_strlen +#[inline(always)] +pub unsafe fn lua_strlen(L: *mut lua_State, i: c_int) -> usize { + lua_objlen(L, i) +} #[inline(always)] pub unsafe fn lua_isfunction(L: *mut lua_State, n: c_int) -> c_int { @@ -300,10 +316,8 @@ pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { } #[inline(always)] -pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static str) { - use std::ffi::CString; - let c_str = CString::new(s).unwrap(); - lua_pushlstring_(L, c_str.as_ptr(), c_str.as_bytes().len()) +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring_(L, s.as_ptr()); } #[inline(always)] @@ -316,11 +330,25 @@ pub unsafe fn lua_getglobal_(L: *mut lua_State, var: *const c_char) { lua_getfield_(L, LUA_GLOBALSINDEX, var) } +#[inline(always)] +pub unsafe fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void { + if lua_islightuserdata(L, idx) != 0 { + return lua_touserdata(L, idx); + } + ptr::null_mut() +} + #[inline(always)] pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { lua_tolstring(L, i, ptr::null_mut()) } +#[inline(always)] +pub unsafe fn lua_xpush(from: *mut lua_State, to: *mut lua_State, idx: c_int) { + lua_pushvalue(from, idx); + lua_xmove(from, to, 1); +} + // // Debug API // @@ -342,9 +370,10 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); /// Type for functions to be called on debug events. -pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; pub fn lua_getinfo(L: *mut lua_State, what: *const c_char, ar: *mut lua_Debug) -> c_int; pub fn lua_getlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; @@ -352,12 +381,7 @@ extern "C" { pub fn lua_getupvalue(L: *mut lua_State, funcindex: c_int, n: c_int) -> *const c_char; pub fn lua_setupvalue(L: *mut lua_State, funcindex: c_int, n: c_int) -> *const c_char; - pub fn lua_sethook( - L: *mut lua_State, - func: Option, - mask: c_int, - count: c_int, - ) -> c_int; + pub fn lua_sethook(L: *mut lua_State, func: Option, mask: c_int, count: c_int) -> c_int; pub fn lua_gethook(L: *mut lua_State) -> Option; pub fn lua_gethookmask(L: *mut lua_State) -> c_int; pub fn lua_gethookcount(L: *mut lua_State) -> c_int; diff --git a/src/ffi/lua51/lualib.rs b/mlua-sys/src/lua51/lualib.rs similarity index 54% rename from src/ffi/lua51/lualib.rs rename to mlua-sys/src/lua51/lualib.rs index 2165221e..3ed0242f 100644 --- a/src/ffi/lua51/lualib.rs +++ b/mlua-sys/src/lua51/lualib.rs @@ -1,26 +1,27 @@ //! Contains definitions from `lualib.h`. -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; use super::lua::lua_State; -pub const LUA_COLIBNAME: &str = "coroutine"; -pub const LUA_TABLIBNAME: &str = "table"; -pub const LUA_IOLIBNAME: &str = "io"; -pub const LUA_OSLIBNAME: &str = "os"; -pub const LUA_STRLIBNAME: &str = "string"; -pub const LUA_MATHLIBNAME: &str = "math"; -pub const LUA_DBLIBNAME: &str = "debug"; -pub const LUA_LOADLIBNAME: &str = "package"; +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_IOLIBNAME: *const c_char = cstr!("io"); +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_LOADLIBNAME: *const c_char = cstr!("package"); #[cfg(feature = "luajit")] -pub const LUA_BITLIBNAME: &str = "bit"; +pub const LUA_BITLIBNAME: *const c_char = cstr!("bit"); #[cfg(feature = "luajit")] -pub const LUA_JITLIBNAME: &str = "jit"; +pub const LUA_JITLIBNAME: *const c_char = cstr!("jit"); #[cfg(feature = "luajit")] -pub const LUA_FFILIBNAME: &str = "ffi"; +pub const LUA_FFILIBNAME: *const c_char = cstr!("ffi"); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua51", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaopen_base(L: *mut lua_State) -> c_int; pub fn luaopen_table(L: *mut lua_State) -> c_int; pub fn luaopen_io(L: *mut lua_State) -> c_int; diff --git a/src/ffi/lua51/mod.rs b/mlua-sys/src/lua51/mod.rs similarity index 100% rename from src/ffi/lua51/mod.rs rename to mlua-sys/src/lua51/mod.rs diff --git a/src/ffi/lua52/compat.rs b/mlua-sys/src/lua52/compat.rs similarity index 81% rename from src/ffi/lua52/compat.rs rename to mlua-sys/src/lua52/compat.rs index b834ab13..4cddd436 100644 --- a/src/ffi/lua52/compat.rs +++ b/mlua-sys/src/lua52/compat.rs @@ -2,7 +2,6 @@ //! //! Based on github.com/keplerproject/lua-compat-5.3 -use std::convert::TryInto; use std::os::raw::{c_char, c_int, c_void}; use std::ptr; @@ -52,7 +51,8 @@ pub unsafe fn lua_isinteger(L: *mut lua_State, idx: c_int) -> c_int { if lua_type(L, idx) == LUA_TNUMBER { let n = lua_tonumber(L, idx); let i = lua_tointeger(L, idx); - if (n - i as lua_Number).abs() < lua_Number::EPSILON { + // Lua 5.3+ returns "false" for `-0.0` + if n.to_bits() == (i as lua_Number).to_bits() { return 1; } } @@ -125,7 +125,7 @@ pub unsafe fn lua_rawget(L: *mut lua_State, idx: c_int) -> c_int { #[inline(always)] pub unsafe fn lua_rawgeti(L: *mut lua_State, idx: c_int, n: lua_Integer) -> c_int { - let n = n.try_into().expect("cannot convert index to lua_Integer"); + let n = n.try_into().expect("cannot convert index to c_int"); lua_rawgeti_(L, idx, n); lua_type(L, -1) } @@ -153,27 +153,17 @@ pub unsafe fn lua_seti(L: *mut lua_State, mut idx: c_int, n: lua_Integer) { #[inline(always)] pub unsafe fn lua_rawseti(L: *mut lua_State, idx: c_int, n: lua_Integer) { - let n = n.try_into().expect("cannot convert index from lua_Integer"); + let n = n.try_into().expect("cannot convert index to c_int"); lua_rawseti_(L, idx, n) } #[inline(always)] -pub unsafe fn lua_dump( - L: *mut lua_State, - writer: lua_Writer, - data: *mut c_void, - _strip: c_int, -) -> c_int { +pub unsafe fn lua_dump(L: *mut lua_State, writer: lua_Writer, data: *mut c_void, _strip: c_int) -> c_int { lua_dump_(L, writer, data) } #[inline(always)] -pub unsafe fn lua_resume( - L: *mut lua_State, - from: *mut lua_State, - narg: c_int, - nres: *mut c_int, -) -> c_int { +pub unsafe fn lua_resume(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int { let ret = lua_resume_(L, from, narg); if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { *nres = lua_gettop(L); @@ -205,24 +195,24 @@ pub unsafe fn luaL_newmetatable(L: *mut lua_State, tname: *const c_char) -> c_in } } -pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char { +pub unsafe fn luaL_tolstring(L: *mut lua_State, mut idx: c_int, len: *mut usize) -> *const c_char { + idx = lua_absindex(L, idx); if luaL_callmeta(L, idx, cstr!("__tostring")) == 0 { - let t = lua_type(L, idx); - match t { + match lua_type(L, idx) { LUA_TNIL => { - lua_pushliteral(L, "nil"); + lua_pushliteral(L, c"nil"); } LUA_TSTRING | LUA_TNUMBER => { lua_pushvalue(L, idx); } LUA_TBOOLEAN => { if lua_toboolean(L, idx) == 0 { - lua_pushliteral(L, "false"); + lua_pushliteral(L, c"false"); } else { - lua_pushliteral(L, "true"); + lua_pushliteral(L, c"true"); } } - _ => { + t => { let tt = luaL_getmetafield(L, idx, cstr!("__name")); let name = if tt == LUA_TSTRING { lua_tostring(L, -1) @@ -231,7 +221,7 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> }; lua_pushfstring(L, cstr!("%s: %p"), name, lua_topointer(L, idx)); if tt != LUA_TNIL { - lua_replace(L, -2); + lua_replace(L, -2); // remove '__name' } } }; @@ -241,14 +231,9 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> lua_tolstring(L, -1, len) } -pub unsafe fn luaL_requiref( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, -) { +pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) { luaL_checkstack(L, 3, cstr!("not enough stack slots available")); - luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED")); + luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE); if lua_getfield(L, -1, modname) == LUA_TNIL { lua_pop(L, 1); lua_pushcfunction(L, openf); @@ -263,3 +248,22 @@ pub unsafe fn luaL_requiref( } lua_replace(L, -2); } + +pub unsafe fn luaL_loadbufferenv( + L: *mut lua_State, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, + mut env: c_int, +) -> c_int { + if env != 0 { + env = lua_absindex(L, env); + } + let status = luaL_loadbufferx(L, data, size, name, mode); + if status == LUA_OK && env != 0 { + lua_pushvalue(L, env); + lua_setupvalue(L, -2, 1); + } + status +} diff --git a/src/ffi/lua52/lauxlib.rs b/mlua-sys/src/lua52/lauxlib.rs similarity index 63% rename from src/ffi/lua52/lauxlib.rs rename to mlua-sys/src/lua52/lauxlib.rs index 6b16aeb7..eb954a46 100644 --- a/src/ffi/lua52/lauxlib.rs +++ b/mlua-sys/src/lua52/lauxlib.rs @@ -8,13 +8,20 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State, lua_Un // Extra error code for 'luaL_load' pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); + +// Key, in the registry, for table of preloaded loaders +pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD"); + #[repr(C)] pub struct luaL_Reg { pub name: *const c_char, pub func: lua_CFunction, } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_checkversion_(L: *mut lua_State, ver: lua_Number); #[link_name = "luaL_getmetafield"] @@ -24,12 +31,8 @@ extern "C" { pub fn luaL_tolstring_(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; pub fn luaL_argerror(L: *mut lua_State, arg: c_int, extramsg: *const c_char) -> c_int; pub fn luaL_checklstring(L: *mut lua_State, arg: c_int, l: *mut usize) -> *const c_char; - pub fn luaL_optlstring( - L: *mut lua_State, - arg: c_int, - def: *const c_char, - l: *mut usize, - ) -> *const c_char; + pub fn luaL_optlstring(L: *mut lua_State, arg: c_int, def: *const c_char, l: *mut usize) + -> *const c_char; pub fn luaL_checknumber(L: *mut lua_State, arg: c_int) -> lua_Number; pub fn luaL_optnumber(L: *mut lua_State, arg: c_int, def: lua_Number) -> lua_Number; pub fn luaL_checkinteger(L: *mut lua_State, arg: c_int) -> lua_Integer; @@ -48,7 +51,7 @@ extern "C" { pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; pub fn luaL_where(L: *mut lua_State, lvl: c_int); - pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> !; + pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> c_int; pub fn luaL_checkoption( L: *mut lua_State, @@ -65,12 +68,12 @@ extern "C" { pub const LUA_NOREF: c_int = -2; pub const LUA_REFNIL: c_int = -1; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_ref(L: *mut lua_State, t: c_int) -> c_int; pub fn luaL_unref(L: *mut lua_State, t: c_int, r#ref: c_int); - pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) - -> c_int; + pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) -> c_int; } #[inline(always)] @@ -78,7 +81,8 @@ pub unsafe fn luaL_loadfile(L: *mut lua_State, f: *const c_char) -> c_int { luaL_loadfilex(L, f, ptr::null()) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_loadbufferx( L: *mut lua_State, buff: *const c_char, @@ -106,12 +110,7 @@ extern "C" { pub fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, level: c_int); #[link_name = "luaL_requiref"] - pub fn luaL_requiref_( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, - ); + pub fn luaL_requiref_(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int); } // @@ -167,18 +166,76 @@ pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) { lua::lua_getfield_(L, lua::LUA_REGISTRYINDEX, n); } -// luaL_opt would be implemented here but it is undocumented, so it's omitted +#[inline(always)] +pub unsafe fn luaL_loadbuffer(L: *mut lua_State, s: *const c_char, sz: usize, n: *const c_char) -> c_int { + luaL_loadbufferx(L, s, sz, n, ptr::null()) +} #[inline(always)] -pub unsafe fn luaL_loadbuffer( +pub unsafe fn luaL_opt( L: *mut lua_State, - s: *const c_char, - sz: usize, - n: *const c_char, -) -> c_int { - luaL_loadbufferx(L, s, sz, n, ptr::null()) + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } } // -// TODO: Generic Buffer Manipulation +// Generic Buffer Manipulation // + +#[cfg(target_arch = "wasm32")] +const BUFSIZ: usize = 1024; // WASI libc's BUFSIZ is 1024 +#[cfg(not(target_arch = "wasm32"))] +const BUFSIZ: usize = libc::BUFSIZ as usize; + +// The buffer size used by the lauxlib buffer system. +// The "16384" workaround is taken from the LuaJIT source code. +pub const LUAL_BUFFERSIZE: usize = if BUFSIZ > 16384 { 8192 } else { BUFSIZ }; + +#[repr(C)] +pub struct luaL_Buffer { + pub b: *mut c_char, // buffer address + pub size: usize, // buffer size + pub n: usize, // number of characters in buffer + pub L: *mut lua_State, + pub initb: [c_char; LUAL_BUFFERSIZE], // initial buffer space +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Buffer); + pub fn luaL_prepbuffsize(B: *mut luaL_Buffer, sz: usize) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Buffer, s: *const c_char, l: usize); + pub fn luaL_addstring(B: *mut luaL_Buffer, s: *const c_char); + pub fn luaL_addvalue(B: *mut luaL_Buffer); + pub fn luaL_pushresult(B: *mut luaL_Buffer); + pub fn luaL_pushresultsize(B: *mut luaL_Buffer, sz: usize); + pub fn luaL_buffinitsize(L: *mut lua_State, B: *mut luaL_Buffer, sz: usize) -> *mut c_char; +} + +// Macro implementations as inline functions + +#[inline(always)] +pub unsafe fn luaL_prepbuffer(B: *mut luaL_Buffer) -> *mut c_char { + luaL_prepbuffsize(B, LUAL_BUFFERSIZE) +} + +#[inline(always)] +pub unsafe fn luaL_addchar(B: *mut luaL_Buffer, c: c_char) { + if (*B).n >= (*B).size { + luaL_prepbuffsize(B, 1); + } + *(*B).b.add((*B).n) = c; + (*B).n += 1; +} + +#[inline(always)] +pub unsafe fn luaL_addsize(B: *mut luaL_Buffer, n: usize) { + (*B).n += n; +} diff --git a/src/ffi/lua52/lua.rs b/mlua-sys/src/lua52/lua.rs similarity index 84% rename from src/ffi/lua52/lua.rs rename to mlua-sys/src/lua52/lua.rs index 44097da6..e5239cee 100644 --- a/src/ffi/lua52/lua.rs +++ b/mlua-sys/src/lua52/lua.rs @@ -1,5 +1,6 @@ //! Contains definitions from `lua.h`. +use std::ffi::CStr; use std::marker::{PhantomData, PhantomPinned}; use std::os::raw::{c_char, c_double, c_int, c_uchar, c_uint, c_void}; use std::ptr; @@ -69,29 +70,32 @@ pub const LUA_RIDX_LAST: lua_Integer = LUA_RIDX_GLOBALS; pub type lua_Number = c_double; /// A Lua integer, usually equivalent to `i64` -pub type lua_Integer = isize; +#[cfg(target_pointer_width = "32")] +pub type lua_Integer = i32; +#[cfg(target_pointer_width = "64")] +pub type lua_Integer = i64; /// A Lua unsigned integer, equivalent to `u32` in Lua 5.2 pub type lua_Unsigned = c_uint; /// Type for native C functions that can be passed to Lua -pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; // Type for functions that read/write blocks when loading/dumping Lua chunks +#[rustfmt::skip] pub type lua_Reader = - unsafe extern "C" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; + unsafe extern "C-unwind" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; +#[rustfmt::skip] pub type lua_Writer = - unsafe extern "C" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; -/// Type for memory-allocation functions -pub type lua_Alloc = unsafe extern "C" fn( - ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, -) -> *mut c_void; +/// Type for memory-allocation functions (no unwinding) +#[rustfmt::skip] +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // State manipulation // @@ -152,20 +156,19 @@ pub const LUA_OPMOD: c_int = 4; pub const LUA_OPPOW: c_int = 5; pub const LUA_OPUNM: c_int = 6; -extern "C" { - pub fn lua_arith(L: *mut lua_State, op: c_int); -} - pub const LUA_OPEQ: c_int = 0; pub const LUA_OPLT: c_int = 1; pub const LUA_OPLE: c_int = 2; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_arith(L: *mut lua_State, op: c_int); pub fn lua_rawequal(L: *mut lua_State, idx1: c_int, idx2: c_int) -> c_int; pub fn lua_compare(L: *mut lua_State, idx1: c_int, idx2: c_int, op: c_int) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Push functions (C -> stack) // @@ -220,13 +223,7 @@ extern "C" { // // 'load' and 'call' functions (load and run Lua code) // - pub fn lua_callk( - L: *mut lua_State, - nargs: c_int, - nresults: c_int, - ctx: c_int, - k: Option, - ); + pub fn lua_callk(L: *mut lua_State, nargs: c_int, nresults: c_int, ctx: c_int, k: Option); pub fn lua_pcallk( L: *mut lua_State, nargs: c_int, @@ -259,16 +256,12 @@ pub unsafe fn lua_pcall(L: *mut lua_State, n: c_int, r: c_int, f: c_int) -> c_in lua_pcallk(L, n, r, f, 0, None) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Coroutine functions // - pub fn lua_yieldk( - L: *mut lua_State, - nresults: c_int, - ctx: c_int, - k: Option, - ) -> c_int; + pub fn lua_yieldk(L: *mut lua_State, nresults: c_int, ctx: c_int, k: Option) -> c_int; #[link_name = "lua_resume"] pub fn lua_resume_(L: *mut lua_State, from: *mut lua_State, narg: c_int) -> c_int; pub fn lua_status(L: *mut lua_State) -> c_int; @@ -295,15 +288,18 @@ pub const LUA_GCISRUNNING: c_int = 9; pub const LUA_GCGEN: c_int = 10; pub const LUA_GCINC: c_int = 11; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_gc(L: *mut lua_State, what: c_int, data: c_int) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Miscellaneous functions // - pub fn lua_error(L: *mut lua_State) -> !; + #[link_name = "lua_error"] + fn lua_error_(L: *mut lua_State) -> c_int; pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_concat(L: *mut lua_State, n: c_int); pub fn lua_len(L: *mut lua_State, idx: c_int); @@ -311,6 +307,15 @@ extern "C" { pub fn lua_setallocf(L: *mut lua_State, f: lua_Alloc, ud: *mut c_void); } +// lua_error does not return but is declared to return int, and Rust translates +// ! to void which can cause link-time errors if the platform linker is aware +// of return types and requires they match (for example: wasm does this). +#[inline(always)] +pub unsafe fn lua_error(L: *mut lua_State) -> ! { + lua_error_(L); + unreachable!(); +} + // // Some useful macros (implemented as Rust functions) // @@ -391,10 +396,8 @@ pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { } #[inline(always)] -pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static str) -> *const c_char { - use std::ffi::CString; - let c_str = CString::new(s).unwrap(); - lua_pushlstring_(L, c_str.as_ptr(), c_str.as_bytes().len()) +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring(L, s.as_ptr()); } #[inline(always)] @@ -402,11 +405,25 @@ pub unsafe fn lua_pushglobaltable(L: *mut lua_State) { lua_rawgeti_(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS as _) } +#[inline(always)] +pub unsafe fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void { + if lua_islightuserdata(L, idx) != 0 { + return lua_touserdata(L, idx); + } + ptr::null_mut() +} + #[inline(always)] pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { lua_tolstring(L, i, ptr::null_mut()) } +#[inline(always)] +pub unsafe fn lua_xpush(from: *mut lua_State, to: *mut lua_State, idx: c_int) { + lua_pushvalue(from, idx); + lua_xmove(from, to, 1); +} + // // Debug API // @@ -428,9 +445,10 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); /// Type for functions to be called on debug events. -pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; pub fn lua_getinfo(L: *mut lua_State, what: *const c_char, ar: *mut lua_Debug) -> c_int; pub fn lua_getlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; @@ -441,7 +459,7 @@ extern "C" { pub fn lua_upvalueid(L: *mut lua_State, fidx: c_int, n: c_int) -> *mut c_void; pub fn lua_upvaluejoin(L: *mut lua_State, fidx1: c_int, n1: c_int, fidx2: c_int, n2: c_int); - pub fn lua_sethook(L: *mut lua_State, func: Option, mask: c_int, count: c_int); + pub fn lua_sethook(L: *mut lua_State, func: Option, mask: c_int, count: c_int) -> c_int; pub fn lua_gethook(L: *mut lua_State) -> Option; pub fn lua_gethookmask(L: *mut lua_State) -> c_int; pub fn lua_gethookcount(L: *mut lua_State) -> c_int; diff --git a/src/ffi/lua53/lualib.rs b/mlua-sys/src/lua52/lualib.rs similarity index 51% rename from src/ffi/lua53/lualib.rs rename to mlua-sys/src/lua52/lualib.rs index 1cedf759..1b3c8445 100644 --- a/src/ffi/lua53/lualib.rs +++ b/mlua-sys/src/lua52/lualib.rs @@ -1,28 +1,27 @@ //! Contains definitions from `lualib.h`. -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; use super::lua::lua_State; -pub const LUA_COLIBNAME: &str = "coroutine"; -pub const LUA_TABLIBNAME: &str = "table"; -pub const LUA_IOLIBNAME: &str = "io"; -pub const LUA_OSLIBNAME: &str = "os"; -pub const LUA_STRLIBNAME: &str = "string"; -pub const LUA_UTF8LIBNAME: &str = "utf8"; -pub const LUA_BITLIBNAME: &str = "bit32"; -pub const LUA_MATHLIBNAME: &str = "math"; -pub const LUA_DBLIBNAME: &str = "debug"; -pub const LUA_LOADLIBNAME: &str = "package"; +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_IOLIBNAME: *const c_char = cstr!("io"); +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_BITLIBNAME: *const c_char = cstr!("bit32"); +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_LOADLIBNAME: *const c_char = cstr!("package"); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua52", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaopen_base(L: *mut lua_State) -> c_int; pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; pub fn luaopen_table(L: *mut lua_State) -> c_int; pub fn luaopen_io(L: *mut lua_State) -> c_int; pub fn luaopen_os(L: *mut lua_State) -> c_int; pub fn luaopen_string(L: *mut lua_State) -> c_int; - pub fn luaopen_utf8(L: *mut lua_State) -> c_int; pub fn luaopen_bit32(L: *mut lua_State) -> c_int; pub fn luaopen_math(L: *mut lua_State) -> c_int; pub fn luaopen_debug(L: *mut lua_State) -> c_int; diff --git a/src/ffi/lua52/mod.rs b/mlua-sys/src/lua52/mod.rs similarity index 100% rename from src/ffi/lua52/mod.rs rename to mlua-sys/src/lua52/mod.rs diff --git a/mlua-sys/src/lua53/compat.rs b/mlua-sys/src/lua53/compat.rs new file mode 100644 index 00000000..30b936ea --- /dev/null +++ b/mlua-sys/src/lua53/compat.rs @@ -0,0 +1,34 @@ +//! MLua compatibility layer for Lua 5.3 + +use std::os::raw::{c_char, c_int}; + +use super::lauxlib::*; +use super::lua::*; + +#[inline(always)] +pub unsafe fn lua_resume(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int { + let ret = lua_resume_(L, from, narg); + if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { + *nres = lua_gettop(L); + } + ret +} + +pub unsafe fn luaL_loadbufferenv( + L: *mut lua_State, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, + mut env: c_int, +) -> c_int { + if env != 0 { + env = lua_absindex(L, env); + } + let status = luaL_loadbufferx(L, data, size, name, mode); + if status == LUA_OK && env != 0 { + lua_pushvalue(L, env); + lua_setupvalue(L, -2, 1); + } + status +} diff --git a/src/ffi/lua54/lauxlib.rs b/mlua-sys/src/lua53/lauxlib.rs similarity index 61% rename from src/ffi/lua54/lauxlib.rs rename to mlua-sys/src/lua53/lauxlib.rs index 5a6c173c..4483b9ab 100644 --- a/src/ffi/lua54/lauxlib.rs +++ b/mlua-sys/src/lua53/lauxlib.rs @@ -1,7 +1,7 @@ //! Contains definitions from `lauxlib.h`. use std::os::raw::{c_char, c_int, c_void}; -use std::ptr; +use std::{mem, ptr}; use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; @@ -9,10 +9,10 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; // Key, in the registry, for table of loaded modules -pub const LUA_LOADED_TABLE: &str = "_LOADED"; +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); // Key, in the registry, for table of preloaded loaders -pub const LUA_PRELOAD_TABLE: &str = "_PRELOAD"; +pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD"); #[repr(C)] pub struct luaL_Reg { @@ -20,20 +20,18 @@ pub struct luaL_Reg { pub func: lua_CFunction, } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_checkversion_(L: *mut lua_State, ver: lua_Number, sz: usize); pub fn luaL_getmetafield(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; pub fn luaL_callmeta(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; - pub fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; + #[link_name = "luaL_tolstring"] + pub fn luaL_tolstring_(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; pub fn luaL_argerror(L: *mut lua_State, arg: c_int, extramsg: *const c_char) -> c_int; pub fn luaL_checklstring(L: *mut lua_State, arg: c_int, l: *mut usize) -> *const c_char; - pub fn luaL_optlstring( - L: *mut lua_State, - arg: c_int, - def: *const c_char, - l: *mut usize, - ) -> *const c_char; + pub fn luaL_optlstring(L: *mut lua_State, arg: c_int, def: *const c_char, l: *mut usize) + -> *const c_char; pub fn luaL_checknumber(L: *mut lua_State, arg: c_int) -> lua_Number; pub fn luaL_optnumber(L: *mut lua_State, arg: c_int, def: lua_Number) -> lua_Number; pub fn luaL_checkinteger(L: *mut lua_State, arg: c_int) -> lua_Integer; @@ -49,7 +47,7 @@ extern "C" { pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; pub fn luaL_where(L: *mut lua_State, lvl: c_int); - pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> !; + pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> c_int; pub fn luaL_checkoption( L: *mut lua_State, @@ -66,12 +64,12 @@ extern "C" { pub const LUA_NOREF: c_int = -2; pub const LUA_REFNIL: c_int = -1; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_ref(L: *mut lua_State, t: c_int) -> c_int; pub fn luaL_unref(L: *mut lua_State, t: c_int, r#ref: c_int); - pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) - -> c_int; + pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) -> c_int; } #[inline(always)] @@ -79,7 +77,8 @@ pub unsafe fn luaL_loadfile(L: *mut lua_State, f: *const c_char) -> c_int { luaL_loadfilex(L, f, ptr::null()) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_loadbufferx( L: *mut lua_State, buff: *const c_char, @@ -93,8 +92,6 @@ extern "C" { pub fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer; - // TODO: luaL_addgsub - pub fn luaL_gsub( L: *mut lua_State, s: *const c_char, @@ -108,12 +105,7 @@ extern "C" { pub fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, level: c_int); - pub fn luaL_requiref( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, - ); + pub fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int); } // @@ -169,18 +161,77 @@ pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) { lua::lua_getfield(L, lua::LUA_REGISTRYINDEX, n); } -// luaL_opt would be implemented here but it is undocumented, so it's omitted +#[inline(always)] +pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char { + luaL_tolstring_(L, lua::lua_absindex(L, idx), len) +} + +#[inline(always)] +pub unsafe fn luaL_loadbuffer(L: *mut lua_State, s: *const c_char, sz: usize, n: *const c_char) -> c_int { + luaL_loadbufferx(L, s, sz, n, ptr::null()) +} #[inline(always)] -pub unsafe fn luaL_loadbuffer( +pub unsafe fn luaL_opt( L: *mut lua_State, - s: *const c_char, - sz: usize, - n: *const c_char, -) -> c_int { - luaL_loadbufferx(L, s, sz, n, ptr::null()) + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } } // -// TODO: Generic Buffer Manipulation +// Generic Buffer Manipulation // + +// The buffer size used by the lauxlib buffer system. +// In Lua 5.3: LUAL_BUFFERSIZE = (int)(0x80 * sizeof(void*) * sizeof(lua_Integer)) +#[rustfmt::skip] +pub const LUAL_BUFFERSIZE: usize = 0x80 * mem::size_of::<*const ()>() * mem::size_of::(); + +#[repr(C)] +pub struct luaL_Buffer { + pub b: *mut c_char, // buffer address + pub size: usize, // buffer size + pub n: usize, // number of characters in buffer + pub L: *mut lua_State, + pub initb: [c_char; LUAL_BUFFERSIZE], // initial buffer space +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Buffer); + pub fn luaL_prepbuffsize(B: *mut luaL_Buffer, sz: usize) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Buffer, s: *const c_char, l: usize); + pub fn luaL_addstring(B: *mut luaL_Buffer, s: *const c_char); + pub fn luaL_addvalue(B: *mut luaL_Buffer); + pub fn luaL_pushresult(B: *mut luaL_Buffer); + pub fn luaL_pushresultsize(B: *mut luaL_Buffer, sz: usize); + pub fn luaL_buffinitsize(L: *mut lua_State, B: *mut luaL_Buffer, sz: usize) -> *mut c_char; +} + +// Macro implementations as inline functions + +#[inline(always)] +pub unsafe fn luaL_prepbuffer(B: *mut luaL_Buffer) -> *mut c_char { + luaL_prepbuffsize(B, LUAL_BUFFERSIZE) +} + +#[inline(always)] +pub unsafe fn luaL_addchar(B: *mut luaL_Buffer, c: c_char) { + if (*B).n >= (*B).size { + luaL_prepbuffsize(B, 1); + } + *(*B).b.add((*B).n) = c; + (*B).n += 1; +} + +#[inline(always)] +pub unsafe fn luaL_addsize(B: *mut luaL_Buffer, n: usize) { + (*B).n += n; +} diff --git a/src/ffi/lua53/lua.rs b/mlua-sys/src/lua53/lua.rs similarity index 85% rename from src/ffi/lua53/lua.rs rename to mlua-sys/src/lua53/lua.rs index 753408c3..2729fdcd 100644 --- a/src/ffi/lua53/lua.rs +++ b/mlua-sys/src/lua53/lua.rs @@ -1,9 +1,9 @@ //! Contains definitions from `lua.h`. +use std::ffi::CStr; use std::marker::{PhantomData, PhantomPinned}; -use std::mem; use std::os::raw::{c_char, c_double, c_int, c_uchar, c_void}; -use std::ptr; +use std::{mem, ptr}; // Mark for precompiled code (`Lua`) pub const LUA_SIGNATURE: &[u8] = b"\x1bLua"; @@ -82,27 +82,27 @@ pub type lua_Unsigned = u64; pub type lua_KContext = isize; /// Type for native C functions that can be passed to Lua -pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; /// Type for continuation functions pub type lua_KFunction = - unsafe extern "C" fn(L: *mut lua_State, status: c_int, ctx: lua_KContext) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, status: c_int, ctx: lua_KContext) -> c_int; // Type for functions that read/write blocks when loading/dumping Lua chunks +#[rustfmt::skip] pub type lua_Reader = - unsafe extern "C" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; + unsafe extern "C-unwind" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; +#[rustfmt::skip] pub type lua_Writer = - unsafe extern "C" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; -/// Type for memory-allocation functions -pub type lua_Alloc = unsafe extern "C" fn( - ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, -) -> *mut c_void; +/// Type for memory-allocation functions (no unwinding) +#[rustfmt::skip] +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // State manipulation // @@ -167,20 +167,19 @@ pub const LUA_OPSHR: c_int = 11; pub const LUA_OPUNM: c_int = 12; pub const LUA_OPBNOT: c_int = 13; -extern "C" { - pub fn lua_arith(L: *mut lua_State, op: c_int); -} - pub const LUA_OPEQ: c_int = 0; pub const LUA_OPLT: c_int = 1; pub const LUA_OPLE: c_int = 2; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_arith(L: *mut lua_State, op: c_int); pub fn lua_rawequal(L: *mut lua_State, idx1: c_int, idx2: c_int) -> c_int; pub fn lua_compare(L: *mut lua_State, idx1: c_int, idx2: c_int, op: c_int) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Push functions (C -> stack) // @@ -252,12 +251,7 @@ extern "C" { mode: *const c_char, ) -> c_int; - pub fn lua_dump( - L: *mut lua_State, - writer: lua_Writer, - data: *mut c_void, - strip: c_int, - ) -> c_int; + pub fn lua_dump(L: *mut lua_State, writer: lua_Writer, data: *mut c_void, strip: c_int) -> c_int; } #[inline(always)] @@ -270,7 +264,8 @@ pub unsafe fn lua_pcall(L: *mut lua_State, n: c_int, r: c_int, f: c_int) -> c_in lua_pcallk(L, n, r, f, 0, None) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Coroutine functions // @@ -304,15 +299,18 @@ pub const LUA_GCSETPAUSE: c_int = 6; pub const LUA_GCSETSTEPMUL: c_int = 7; pub const LUA_GCISRUNNING: c_int = 9; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_gc(L: *mut lua_State, what: c_int, data: c_int) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Miscellaneous functions // - pub fn lua_error(L: *mut lua_State) -> !; + #[link_name = "lua_error"] + fn lua_error_(L: *mut lua_State) -> c_int; pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_concat(L: *mut lua_State, n: c_int); pub fn lua_len(L: *mut lua_State, idx: c_int); @@ -321,6 +319,15 @@ extern "C" { pub fn lua_setallocf(L: *mut lua_State, f: lua_Alloc, ud: *mut c_void); } +// lua_error does not return but is declared to return int, and Rust translates +// ! to void which can cause link-time errors if the platform linker is aware +// of return types and requires they match (for example: wasm does this). +#[inline(always)] +pub unsafe fn lua_error(L: *mut lua_State) -> ! { + lua_error_(L); + unreachable!(); +} + // // Some useful macros (implemented as Rust functions) // @@ -401,10 +408,8 @@ pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { } #[inline(always)] -pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static str) -> *const c_char { - use std::ffi::CString; - let c_str = CString::new(s).unwrap(); - lua_pushlstring(L, c_str.as_ptr(), c_str.as_bytes().len()) +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring(L, s.as_ptr()); } #[inline(always)] @@ -412,6 +417,14 @@ pub unsafe fn lua_pushglobaltable(L: *mut lua_State) -> c_int { lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS) } +#[inline(always)] +pub unsafe fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void { + if lua_islightuserdata(L, idx) != 0 { + return lua_touserdata(L, idx); + } + ptr::null_mut() +} + #[inline(always)] pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { lua_tolstring(L, i, ptr::null_mut()) @@ -434,6 +447,12 @@ pub unsafe fn lua_replace(L: *mut lua_State, idx: c_int) { lua_pop(L, 1) } +#[inline(always)] +pub unsafe fn lua_xpush(from: *mut lua_State, to: *mut lua_State, idx: c_int) { + lua_pushvalue(from, idx); + lua_xmove(from, to, 1); +} + // // Debug API // @@ -455,9 +474,10 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); /// Type for functions to be called on debug events. -pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; pub fn lua_getinfo(L: *mut lua_State, what: *const c_char, ar: *mut lua_Debug) -> c_int; pub fn lua_getlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; diff --git a/mlua-sys/src/lua53/lualib.rs b/mlua-sys/src/lua53/lualib.rs new file mode 100644 index 00000000..c29d8160 --- /dev/null +++ b/mlua-sys/src/lua53/lualib.rs @@ -0,0 +1,34 @@ +//! Contains definitions from `lualib.h`. + +use std::os::raw::{c_char, c_int}; + +use super::lua::lua_State; + +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_IOLIBNAME: *const c_char = cstr!("io"); +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_UTF8LIBNAME: *const c_char = cstr!("utf8"); +pub const LUA_BITLIBNAME: *const c_char = cstr!("bit32"); +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_LOADLIBNAME: *const c_char = cstr!("package"); + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua53", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaopen_base(L: *mut lua_State) -> c_int; + pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; + pub fn luaopen_table(L: *mut lua_State) -> c_int; + pub fn luaopen_io(L: *mut lua_State) -> c_int; + pub fn luaopen_os(L: *mut lua_State) -> c_int; + pub fn luaopen_string(L: *mut lua_State) -> c_int; + pub fn luaopen_utf8(L: *mut lua_State) -> c_int; + pub fn luaopen_bit32(L: *mut lua_State) -> c_int; + pub fn luaopen_math(L: *mut lua_State) -> c_int; + pub fn luaopen_debug(L: *mut lua_State) -> c_int; + pub fn luaopen_package(L: *mut lua_State) -> c_int; + + // open all builtin libraries + pub fn luaL_openlibs(L: *mut lua_State); +} diff --git a/src/ffi/lua53/mod.rs b/mlua-sys/src/lua53/mod.rs similarity index 100% rename from src/ffi/lua53/mod.rs rename to mlua-sys/src/lua53/mod.rs diff --git a/src/ffi/lua53/lauxlib.rs b/mlua-sys/src/lua54/lauxlib.rs similarity index 55% rename from src/ffi/lua53/lauxlib.rs rename to mlua-sys/src/lua54/lauxlib.rs index cc83c616..26b31bfe 100644 --- a/src/ffi/lua53/lauxlib.rs +++ b/mlua-sys/src/lua54/lauxlib.rs @@ -1,7 +1,7 @@ //! Contains definitions from `lauxlib.h`. -use std::os::raw::{c_char, c_int, c_void}; -use std::ptr; +use std::os::raw::{c_char, c_double, c_int, c_long, c_void}; +use std::{mem, ptr}; use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; @@ -9,10 +9,10 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; // Key, in the registry, for table of loaded modules -pub const LUA_LOADED_TABLE: &str = "_LOADED"; +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); // Key, in the registry, for table of preloaded loaders -pub const LUA_PRELOAD_TABLE: &str = "_PRELOAD"; +pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD"); #[repr(C)] pub struct luaL_Reg { @@ -20,7 +20,8 @@ pub struct luaL_Reg { pub func: lua_CFunction, } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_checkversion_(L: *mut lua_State, ver: lua_Number, sz: usize); pub fn luaL_getmetafield(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; @@ -28,12 +29,8 @@ extern "C" { pub fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; pub fn luaL_argerror(L: *mut lua_State, arg: c_int, extramsg: *const c_char) -> c_int; pub fn luaL_checklstring(L: *mut lua_State, arg: c_int, l: *mut usize) -> *const c_char; - pub fn luaL_optlstring( - L: *mut lua_State, - arg: c_int, - def: *const c_char, - l: *mut usize, - ) -> *const c_char; + pub fn luaL_optlstring(L: *mut lua_State, arg: c_int, def: *const c_char, l: *mut usize) + -> *const c_char; pub fn luaL_checknumber(L: *mut lua_State, arg: c_int) -> lua_Number; pub fn luaL_optnumber(L: *mut lua_State, arg: c_int, def: lua_Number) -> lua_Number; pub fn luaL_checkinteger(L: *mut lua_State, arg: c_int) -> lua_Integer; @@ -49,7 +46,7 @@ extern "C" { pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; pub fn luaL_where(L: *mut lua_State, lvl: c_int); - pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> !; + pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> c_int; pub fn luaL_checkoption( L: *mut lua_State, @@ -66,12 +63,12 @@ extern "C" { pub const LUA_NOREF: c_int = -2; pub const LUA_REFNIL: c_int = -1; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_ref(L: *mut lua_State, t: c_int) -> c_int; pub fn luaL_unref(L: *mut lua_State, t: c_int, r#ref: c_int); - pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) - -> c_int; + pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) -> c_int; } #[inline(always)] @@ -79,7 +76,8 @@ pub unsafe fn luaL_loadfile(L: *mut lua_State, f: *const c_char) -> c_int { luaL_loadfilex(L, f, ptr::null()) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaL_loadbufferx( L: *mut lua_State, buff: *const c_char, @@ -93,6 +91,8 @@ extern "C" { pub fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer; + pub fn luaL_addgsub(B: *mut luaL_Buffer, s: *const c_char, p: *const c_char, r: *const c_char); + pub fn luaL_gsub( L: *mut lua_State, s: *const c_char, @@ -106,12 +106,7 @@ extern "C" { pub fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, level: c_int); - pub fn luaL_requiref( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, - ); + pub fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int); } // @@ -167,18 +162,120 @@ pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) { lua::lua_getfield(L, lua::LUA_REGISTRYINDEX, n); } -// luaL_opt would be implemented here but it is undocumented, so it's omitted - #[inline(always)] -pub unsafe fn luaL_loadbuffer( +pub unsafe fn luaL_loadbuffer(L: *mut lua_State, s: *const c_char, sz: usize, n: *const c_char) -> c_int { + luaL_loadbufferx(L, s, sz, n, ptr::null()) +} + +pub unsafe fn luaL_loadbufferenv( L: *mut lua_State, - s: *const c_char, - sz: usize, - n: *const c_char, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, + mut env: c_int, ) -> c_int { - luaL_loadbufferx(L, s, sz, n, ptr::null()) + if env != 0 { + env = lua::lua_absindex(L, env); + } + let status = luaL_loadbufferx(L, data, size, name, mode); + if status == lua::LUA_OK && env != 0 { + lua::lua_pushvalue(L, env); + lua::lua_setupvalue(L, -2, 1); + } + status +} + +#[inline(always)] +pub unsafe fn luaL_opt( + L: *mut lua_State, + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } } // -// TODO: Generic Buffer Manipulation +// Generic Buffer Manipulation // + +// The buffer size used by the lauxlib buffer system. +// LUAL_BUFFERSIZE = (int)(16 * sizeof(void*) * sizeof(lua_Number)) +#[rustfmt::skip] +pub const LUAL_BUFFERSIZE: usize = 16 * mem::size_of::<*const ()>() * mem::size_of::(); + +// Union used for the initial buffer with maximum alignment. +// This ensures proper alignment for the buffer data. +#[repr(C)] +pub union luaL_BufferInit { + // Alignment matches LUAI_MAXALIGN + pub _align_n: lua_Number, + pub _align_u: c_double, + pub _align_s: *mut c_void, + pub _align_i: lua_Integer, + pub _align_l: c_long, + // Initial buffer space + pub b: [c_char; LUAL_BUFFERSIZE], +} + +#[repr(C)] +pub struct luaL_Buffer { + pub b: *mut c_char, // buffer address + pub size: usize, // buffer size + pub n: usize, // number of characters in buffer + pub L: *mut lua_State, + pub init: luaL_BufferInit, // initial buffer (union with alignment) +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Buffer); + pub fn luaL_prepbuffsize(B: *mut luaL_Buffer, sz: usize) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Buffer, s: *const c_char, l: usize); + pub fn luaL_addstring(B: *mut luaL_Buffer, s: *const c_char); + pub fn luaL_addvalue(B: *mut luaL_Buffer); + pub fn luaL_pushresult(B: *mut luaL_Buffer); + pub fn luaL_pushresultsize(B: *mut luaL_Buffer, sz: usize); + pub fn luaL_buffinitsize(L: *mut lua_State, B: *mut luaL_Buffer, sz: usize) -> *mut c_char; +} + +// Macro implementations as inline functions + +#[inline(always)] +pub unsafe fn luaL_prepbuffer(B: *mut luaL_Buffer) -> *mut c_char { + luaL_prepbuffsize(B, LUAL_BUFFERSIZE) +} + +#[inline(always)] +pub unsafe fn luaL_addchar(B: *mut luaL_Buffer, c: c_char) { + if (*B).n >= (*B).size { + luaL_prepbuffsize(B, 1); + } + *(*B).b.add((*B).n) = c; + (*B).n += 1; +} + +#[inline(always)] +pub unsafe fn luaL_addsize(B: *mut luaL_Buffer, n: usize) { + (*B).n += n; +} + +#[inline(always)] +pub unsafe fn luaL_buffsub(B: *mut luaL_Buffer, n: usize) { + (*B).n -= n; +} + +#[inline(always)] +pub unsafe fn luaL_bufflen(B: *mut luaL_Buffer) -> usize { + (*B).n +} + +#[inline(always)] +pub unsafe fn luaL_buffaddr(B: *mut luaL_Buffer) -> *mut c_char { + (*B).b +} diff --git a/src/ffi/lua54/lua.rs b/mlua-sys/src/lua54/lua.rs similarity index 82% rename from src/ffi/lua54/lua.rs rename to mlua-sys/src/lua54/lua.rs index 113f3247..15a30444 100644 --- a/src/ffi/lua54/lua.rs +++ b/mlua-sys/src/lua54/lua.rs @@ -1,9 +1,9 @@ //! Contains definitions from `lua.h`. +use std::ffi::CStr; use std::marker::{PhantomData, PhantomPinned}; -use std::mem; use std::os::raw::{c_char, c_double, c_int, c_uchar, c_ushort, c_void}; -use std::ptr; +use std::{mem, ptr}; // Mark for precompiled code (`Lua`) pub const LUA_SIGNATURE: &[u8] = b"\x1bLua"; @@ -81,38 +81,40 @@ pub type lua_Unsigned = u64; pub type lua_KContext = isize; /// Type for native C functions that can be passed to Lua -pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; /// Type for continuation functions pub type lua_KFunction = - unsafe extern "C" fn(L: *mut lua_State, status: c_int, ctx: lua_KContext) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, status: c_int, ctx: lua_KContext) -> c_int; // Type for functions that read/write blocks when loading/dumping Lua chunks +#[rustfmt::skip] pub type lua_Reader = - unsafe extern "C" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; + unsafe extern "C-unwind" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; +#[rustfmt::skip] pub type lua_Writer = - unsafe extern "C" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; + unsafe extern "C-unwind" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; -/// Type for memory-allocation functions -pub type lua_Alloc = unsafe extern "C" fn( - ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, -) -> *mut c_void; +/// Type for memory-allocation functions (no unwinding) +#[rustfmt::skip] +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; /// Type for warning functions -pub type lua_WarnFunction = - unsafe extern "C" fn(ud: *mut c_void, msg: *const c_char, tocont: c_int); +pub type lua_WarnFunction = unsafe extern "C-unwind" fn(ud: *mut c_void, msg: *const c_char, tocont: c_int); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // State manipulation // pub fn lua_newstate(f: lua_Alloc, ud: *mut c_void) -> *mut lua_State; pub fn lua_close(L: *mut lua_State); pub fn lua_newthread(L: *mut lua_State) -> *mut lua_State; + // Deprecated in Lua 5.4.6 pub fn lua_resetthread(L: *mut lua_State) -> c_int; + #[cfg(feature = "vendored")] + pub fn lua_closethread(L: *mut lua_State, from: *mut lua_State) -> c_int; pub fn lua_atpanic(L: *mut lua_State, panicf: lua_CFunction) -> lua_CFunction; @@ -146,13 +148,21 @@ extern "C" { pub fn lua_tointegerx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Integer; pub fn lua_toboolean(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; - pub fn lua_rawlen(L: *mut lua_State, idx: c_int) -> usize; + #[link_name = "lua_rawlen"] + fn lua_rawlen_(L: *mut lua_State, idx: c_int) -> lua_Unsigned; pub fn lua_tocfunction(L: *mut lua_State, idx: c_int) -> Option; pub fn lua_touserdata(L: *mut lua_State, idx: c_int) -> *mut c_void; pub fn lua_tothread(L: *mut lua_State, idx: c_int) -> *mut lua_State; pub fn lua_topointer(L: *mut lua_State, idx: c_int) -> *const c_void; } +// lua_rawlen's return type changed from size_t to lua_Unsigned int in Lua 5.4. +// This adapts the crate API to the new Lua ABI. +#[inline(always)] +pub unsafe fn lua_rawlen(L: *mut lua_State, idx: c_int) -> usize { + lua_rawlen_(L, idx) as usize +} + // // Comparison and arithmetic functions // @@ -171,20 +181,19 @@ pub const LUA_OPSHR: c_int = 11; pub const LUA_OPUNM: c_int = 12; pub const LUA_OPBNOT: c_int = 13; -extern "C" { - pub fn lua_arith(L: *mut lua_State, op: c_int); -} - pub const LUA_OPEQ: c_int = 0; pub const LUA_OPLT: c_int = 1; pub const LUA_OPLE: c_int = 2; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_arith(L: *mut lua_State, op: c_int); pub fn lua_rawequal(L: *mut lua_State, idx1: c_int, idx2: c_int) -> c_int; pub fn lua_compare(L: *mut lua_State, idx1: c_int, idx2: c_int, op: c_int) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Push functions (C -> stack) // @@ -256,12 +265,7 @@ extern "C" { mode: *const c_char, ) -> c_int; - pub fn lua_dump( - L: *mut lua_State, - writer: lua_Writer, - data: *mut c_void, - strip: c_int, - ) -> c_int; + pub fn lua_dump(L: *mut lua_State, writer: lua_Writer, data: *mut c_void, strip: c_int) -> c_int; } #[inline(always)] @@ -274,7 +278,8 @@ pub unsafe fn lua_pcall(L: *mut lua_State, n: c_int, r: c_int, f: c_int) -> c_in lua_pcallk(L, n, r, f, 0, None) } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Coroutine functions // @@ -284,12 +289,7 @@ extern "C" { ctx: lua_KContext, k: Option, ) -> c_int; - pub fn lua_resume( - L: *mut lua_State, - from: *mut lua_State, - narg: c_int, - nres: *mut c_int, - ) -> c_int; + pub fn lua_resume(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int; pub fn lua_status(L: *mut lua_State) -> c_int; pub fn lua_isyieldable(L: *mut lua_State) -> c_int; } @@ -302,7 +302,8 @@ pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { // // Warning-related functions // -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_setwarnf(L: *mut lua_State, f: Option, ud: *mut c_void); pub fn lua_warning(L: *mut lua_State, msg: *const c_char, tocont: c_int); } @@ -322,15 +323,18 @@ pub const LUA_GCISRUNNING: c_int = 9; pub const LUA_GCGEN: c_int = 10; pub const LUA_GCINC: c_int = 11; -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_gc(L: *mut lua_State, what: c_int, ...) -> c_int; } -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { // // Miscellaneous functions // - pub fn lua_error(L: *mut lua_State) -> !; + #[link_name = "lua_error"] + fn lua_error_(L: *mut lua_State) -> c_int; pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_concat(L: *mut lua_State, n: c_int); pub fn lua_len(L: *mut lua_State, idx: c_int); @@ -342,6 +346,15 @@ extern "C" { pub fn lua_closeslot(L: *mut lua_State, idx: c_int); } +// lua_error does not return but is declared to return int, and Rust translates +// ! to void which can cause link-time errors if the platform linker is aware +// of return types and requires they match (for example: wasm does this). +#[inline(always)] +pub unsafe fn lua_error(L: *mut lua_State) -> ! { + lua_error_(L); + unreachable!(); +} + // // Some useful macros (implemented as Rust functions) // @@ -422,10 +435,8 @@ pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { } #[inline(always)] -pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static str) -> *const c_char { - use std::ffi::CString; - let c_str = CString::new(s).unwrap(); - lua_pushlstring(L, c_str.as_ptr(), c_str.as_bytes().len()) +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring(L, s.as_ptr()); } #[inline(always)] @@ -433,6 +444,14 @@ pub unsafe fn lua_pushglobaltable(L: *mut lua_State) -> c_int { lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS) } +#[inline(always)] +pub unsafe fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void { + if lua_islightuserdata(L, idx) != 0 { + return lua_touserdata(L, idx); + } + ptr::null_mut() +} + #[inline(always)] pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { lua_tolstring(L, i, ptr::null_mut()) @@ -455,6 +474,12 @@ pub unsafe fn lua_replace(L: *mut lua_State, idx: c_int) { lua_pop(L, 1) } +#[inline(always)] +pub unsafe fn lua_xpush(from: *mut lua_State, to: *mut lua_State, idx: c_int) { + lua_pushvalue(from, idx); + lua_xmove(from, to, 1); +} + #[inline(always)] pub unsafe fn lua_newuserdata(L: *mut lua_State, sz: usize) -> *mut c_void { lua_newuserdatauv(L, sz, 1) @@ -491,9 +516,10 @@ pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); /// Type for functions to be called on debug events. -pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; pub fn lua_getinfo(L: *mut lua_State, what: *const c_char, ar: *mut lua_Debug) -> c_int; pub fn lua_getlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; diff --git a/src/ffi/lua54/lualib.rs b/mlua-sys/src/lua54/lualib.rs similarity index 51% rename from src/ffi/lua54/lualib.rs rename to mlua-sys/src/lua54/lualib.rs index 3626004e..0f666643 100644 --- a/src/ffi/lua54/lualib.rs +++ b/mlua-sys/src/lua54/lualib.rs @@ -1,20 +1,21 @@ //! Contains definitions from `lualib.h`. -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; use super::lua::lua_State; -pub const LUA_COLIBNAME: &str = "coroutine"; -pub const LUA_TABLIBNAME: &str = "table"; -pub const LUA_IOLIBNAME: &str = "io"; -pub const LUA_OSLIBNAME: &str = "os"; -pub const LUA_STRLIBNAME: &str = "string"; -pub const LUA_UTF8LIBNAME: &str = "utf8"; -pub const LUA_MATHLIBNAME: &str = "math"; -pub const LUA_DBLIBNAME: &str = "debug"; -pub const LUA_LOADLIBNAME: &str = "package"; +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_IOLIBNAME: *const c_char = cstr!("io"); +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_UTF8LIBNAME: *const c_char = cstr!("utf8"); +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_LOADLIBNAME: *const c_char = cstr!("package"); -extern "C" { +#[cfg_attr(all(windows, raw_dylib), link(name = "lua54", kind = "raw-dylib"))] +unsafe extern "C-unwind" { pub fn luaopen_base(L: *mut lua_State) -> c_int; pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; pub fn luaopen_table(L: *mut lua_State) -> c_int; diff --git a/src/ffi/lua54/mod.rs b/mlua-sys/src/lua54/mod.rs similarity index 100% rename from src/ffi/lua54/mod.rs rename to mlua-sys/src/lua54/mod.rs diff --git a/mlua-sys/src/lua55/lauxlib.rs b/mlua-sys/src/lua55/lauxlib.rs new file mode 100644 index 00000000..13a8524f --- /dev/null +++ b/mlua-sys/src/lua55/lauxlib.rs @@ -0,0 +1,299 @@ +//! Contains definitions from `lauxlib.h`. + +use std::os::raw::{c_char, c_double, c_int, c_long, c_uint, c_void}; +use std::{mem, ptr}; + +use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; + +// Extra error code for 'luaL_loadfilex' +pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; + +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); + +// Key, in the registry, for table of preloaded loaders +pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD"); + +#[repr(C)] +pub struct luaL_Reg { + pub name: *const c_char, + pub func: lua_CFunction, +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_checkversion_(L: *mut lua_State, ver: lua_Number, sz: usize); + + pub fn luaL_getmetafield(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; + pub fn luaL_callmeta(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; + pub fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; + pub fn luaL_argerror(L: *mut lua_State, arg: c_int, extramsg: *const c_char) -> c_int; + pub fn luaL_checklstring(L: *mut lua_State, arg: c_int, l: *mut usize) -> *const c_char; + pub fn luaL_optlstring(L: *mut lua_State, arg: c_int, def: *const c_char, l: *mut usize) + -> *const c_char; + pub fn luaL_checknumber(L: *mut lua_State, arg: c_int) -> lua_Number; + pub fn luaL_optnumber(L: *mut lua_State, arg: c_int, def: lua_Number) -> lua_Number; + pub fn luaL_checkinteger(L: *mut lua_State, arg: c_int) -> lua_Integer; + pub fn luaL_optinteger(L: *mut lua_State, arg: c_int, def: lua_Integer) -> lua_Integer; + + pub fn luaL_checkstack(L: *mut lua_State, sz: c_int, msg: *const c_char); + pub fn luaL_checktype(L: *mut lua_State, arg: c_int, t: c_int); + pub fn luaL_checkany(L: *mut lua_State, arg: c_int); + + pub fn luaL_newmetatable(L: *mut lua_State, tname: *const c_char) -> c_int; + pub fn luaL_setmetatable(L: *mut lua_State, tname: *const c_char); + pub fn luaL_testudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; + pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; + + pub fn luaL_where(L: *mut lua_State, lvl: c_int); + pub fn luaL_error(L: *mut lua_State, fmt: *const c_char, ...) -> c_int; + + pub fn luaL_checkoption( + L: *mut lua_State, + arg: c_int, + def: *const c_char, + lst: *const *const c_char, + ) -> c_int; + + pub fn luaL_fileresult(L: *mut lua_State, stat: c_int, fname: *const c_char) -> c_int; + pub fn luaL_execresult(L: *mut lua_State, stat: c_int) -> c_int; + pub fn luaL_alloc(L: *mut lua_State, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; +} + +// Pre-defined references +pub const LUA_NOREF: c_int = -2; +pub const LUA_REFNIL: c_int = -1; + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_ref(L: *mut lua_State, t: c_int) -> c_int; + pub fn luaL_unref(L: *mut lua_State, t: c_int, r#ref: c_int); + + pub fn luaL_loadfilex(L: *mut lua_State, filename: *const c_char, mode: *const c_char) -> c_int; +} + +#[inline(always)] +pub unsafe fn luaL_loadfile(L: *mut lua_State, f: *const c_char) -> c_int { + luaL_loadfilex(L, f, ptr::null()) +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_loadbufferx( + L: *mut lua_State, + buff: *const c_char, + sz: usize, + name: *const c_char, + mode: *const c_char, + ) -> c_int; + pub fn luaL_loadstring(L: *mut lua_State, s: *const c_char) -> c_int; + + pub fn luaL_newstate() -> *mut lua_State; + + #[link_name = "luaL_makeseed"] + pub fn luaL_makeseed_(L: *mut lua_State) -> c_uint; + + pub fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer; + + pub fn luaL_addgsub(B: *mut luaL_Buffer, s: *const c_char, p: *const c_char, r: *const c_char); + + pub fn luaL_gsub( + L: *mut lua_State, + s: *const c_char, + p: *const c_char, + r: *const c_char, + ) -> *const c_char; + + pub fn luaL_setfuncs(L: *mut lua_State, l: *const luaL_Reg, nup: c_int); + + pub fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_char) -> c_int; + + pub fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, level: c_int); + + pub fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int); +} + +// +// Some useful macros (implemented as Rust functions) +// + +// TODO: luaL_newlibtable, luaL_newlib + +#[inline(always)] +pub unsafe fn luaL_argcheck(L: *mut lua_State, cond: c_int, arg: c_int, extramsg: *const c_char) { + if cond == 0 { + luaL_argerror(L, arg, extramsg); + } +} + +#[inline(always)] +pub unsafe fn luaL_checkstring(L: *mut lua_State, n: c_int) -> *const c_char { + luaL_checklstring(L, n, ptr::null_mut()) +} + +#[inline(always)] +pub unsafe fn luaL_optstring(L: *mut lua_State, n: c_int, d: *const c_char) -> *const c_char { + luaL_optlstring(L, n, d, ptr::null_mut()) +} + +#[inline(always)] +pub unsafe fn luaL_typename(L: *mut lua_State, i: c_int) -> *const c_char { + lua::lua_typename(L, lua::lua_type(L, i)) +} + +#[inline(always)] +pub unsafe fn luaL_dofile(L: *mut lua_State, filename: *const c_char) -> c_int { + let status = luaL_loadfile(L, filename); + if status == 0 { + lua::lua_pcall(L, 0, lua::LUA_MULTRET, 0) + } else { + status + } +} + +#[inline(always)] +pub unsafe fn luaL_dostring(L: *mut lua_State, s: *const c_char) -> c_int { + let status = luaL_loadstring(L, s); + if status == 0 { + lua::lua_pcall(L, 0, lua::LUA_MULTRET, 0) + } else { + status + } +} + +#[inline(always)] +pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) { + lua::lua_getfield(L, lua::LUA_REGISTRYINDEX, n); +} + +#[inline(always)] +pub unsafe fn luaL_loadbuffer(L: *mut lua_State, s: *const c_char, sz: usize, n: *const c_char) -> c_int { + luaL_loadbufferx(L, s, sz, n, ptr::null()) +} + +pub unsafe fn luaL_loadbufferenv( + L: *mut lua_State, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, + mut env: c_int, +) -> c_int { + if env != 0 { + env = lua::lua_absindex(L, env); + } + let status = luaL_loadbufferx(L, data, size, name, mode); + if status == lua::LUA_OK && env != 0 { + lua::lua_pushvalue(L, env); + lua::lua_setupvalue(L, -2, 1); + } + status +} + +pub unsafe fn luaL_makeseed(L: *mut lua_State) -> c_uint { + #[cfg(macos)] + return libc::arc4random(); + #[cfg(linux)] + { + let mut seed = 0u32; + let buf = &mut seed as *mut _ as *mut c_void; + if libc::getrandom(buf, 4, libc::GRND_NONBLOCK) == 4 { + return seed; + } + } + luaL_makeseed_(L) +} + +#[inline(always)] +pub unsafe fn luaL_opt( + L: *mut lua_State, + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } +} + +// +// Generic Buffer Manipulation +// + +// The buffer size used by the lauxlib buffer system. +// LUAL_BUFFERSIZE = (int)(16 * sizeof(void*) * sizeof(lua_Number)) +#[rustfmt::skip] +pub const LUAL_BUFFERSIZE: usize = 16 * mem::size_of::<*const ()>() * mem::size_of::(); + +// Union used for the initial buffer with maximum alignment. +// This ensures proper alignment for the buffer data. +#[repr(C)] +pub union luaL_BufferInit { + // Alignment matches LUAI_MAXALIGN + pub _align_n: lua_Number, + pub _align_u: c_double, + pub _align_s: *mut c_void, + pub _align_i: lua_Integer, + pub _align_l: c_long, + // Initial buffer space + pub b: [c_char; LUAL_BUFFERSIZE], +} + +#[repr(C)] +pub struct luaL_Buffer { + pub b: *mut c_char, // buffer address + pub size: usize, // buffer size + pub n: usize, // number of characters in buffer + pub L: *mut lua_State, + pub init: luaL_BufferInit, // initial buffer (union with alignment) +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Buffer); + pub fn luaL_prepbuffsize(B: *mut luaL_Buffer, sz: usize) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Buffer, s: *const c_char, l: usize); + pub fn luaL_addstring(B: *mut luaL_Buffer, s: *const c_char); + pub fn luaL_addvalue(B: *mut luaL_Buffer); + pub fn luaL_pushresult(B: *mut luaL_Buffer); + pub fn luaL_pushresultsize(B: *mut luaL_Buffer, sz: usize); + pub fn luaL_buffinitsize(L: *mut lua_State, B: *mut luaL_Buffer, sz: usize) -> *mut c_char; +} + +// Macro implementations as inline functions + +#[inline(always)] +pub unsafe fn luaL_prepbuffer(B: *mut luaL_Buffer) -> *mut c_char { + luaL_prepbuffsize(B, LUAL_BUFFERSIZE) +} + +#[inline(always)] +pub unsafe fn luaL_addchar(B: *mut luaL_Buffer, c: c_char) { + if (*B).n >= (*B).size { + luaL_prepbuffsize(B, 1); + } + *(*B).b.add((*B).n) = c; + (*B).n += 1; +} + +#[inline(always)] +pub unsafe fn luaL_addsize(B: *mut luaL_Buffer, n: usize) { + (*B).n += n; +} + +#[inline(always)] +pub unsafe fn luaL_buffsub(B: *mut luaL_Buffer, n: usize) { + (*B).n -= n; +} + +#[inline(always)] +pub unsafe fn luaL_bufflen(B: *mut luaL_Buffer) -> usize { + (*B).n +} + +#[inline(always)] +pub unsafe fn luaL_buffaddr(B: *mut luaL_Buffer) -> *mut c_char { + (*B).b +} diff --git a/mlua-sys/src/lua55/lua.rs b/mlua-sys/src/lua55/lua.rs new file mode 100644 index 00000000..dd6067d6 --- /dev/null +++ b/mlua-sys/src/lua55/lua.rs @@ -0,0 +1,578 @@ +//! Contains definitions from `lua.h`. + +use std::ffi::CStr; +use std::marker::{PhantomData, PhantomPinned}; +use std::os::raw::{c_char, c_double, c_int, c_uchar, c_uint, c_void}; +use std::{mem, ptr}; + +// Mark for precompiled code (`Lua`) +pub const LUA_SIGNATURE: &[u8] = b"\x1bLua"; + +// Option for multiple returns in 'lua_pcall' and 'lua_call' +pub const LUA_MULTRET: c_int = -1; + +// Size of the Lua stack +#[doc(hidden)] +pub const LUAI_MAXSTACK: c_int = c_int::MAX; + +// Size of a raw memory area associated with a Lua state with very fast access. +pub const LUA_EXTRASPACE: usize = mem::size_of::<*const ()>(); + +// +// Pseudo-indices +// +pub const LUA_REGISTRYINDEX: c_int = -(c_int::MAX / 2 + 1000); + +pub const fn lua_upvalueindex(i: c_int) -> c_int { + LUA_REGISTRYINDEX - i +} + +// +// Thread status +// +pub const LUA_OK: c_int = 0; +pub const LUA_YIELD: c_int = 1; +pub const LUA_ERRRUN: c_int = 2; +pub const LUA_ERRSYNTAX: c_int = 3; +pub const LUA_ERRMEM: c_int = 4; +pub const LUA_ERRERR: c_int = 5; + +/// A raw Lua state associated with a thread. +#[repr(C)] +pub struct lua_State { + _data: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +// +// Basic types +// +pub const LUA_TNONE: c_int = -1; + +pub const LUA_TNIL: c_int = 0; +pub const LUA_TBOOLEAN: c_int = 1; +pub const LUA_TLIGHTUSERDATA: c_int = 2; +pub const LUA_TNUMBER: c_int = 3; +pub const LUA_TSTRING: c_int = 4; +pub const LUA_TTABLE: c_int = 5; +pub const LUA_TFUNCTION: c_int = 6; +pub const LUA_TUSERDATA: c_int = 7; +pub const LUA_TTHREAD: c_int = 8; + +pub const LUA_NUMTYPES: c_int = 9; + +/// Minimum Lua stack available to a C function +pub const LUA_MINSTACK: c_int = 20; + +// Predefined values in the registry +// index 1 is reserved for the reference mechanism +pub const LUA_RIDX_GLOBALS: lua_Integer = 2; +pub const LUA_RIDX_MAINTHREAD: lua_Integer = 3; +pub const LUA_RIDX_LAST: lua_Integer = 3; + +/// A Lua number, usually equivalent to `f64` +pub type lua_Number = c_double; + +/// A Lua integer, usually equivalent to `i64` +pub type lua_Integer = i64; + +/// A Lua unsigned integer, usually equivalent to `u64` +pub type lua_Unsigned = u64; + +/// Type for continuation-function contexts +pub type lua_KContext = isize; + +/// Type for native C functions that can be passed to Lua +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; + +/// Type for continuation functions +pub type lua_KFunction = + unsafe extern "C-unwind" fn(L: *mut lua_State, status: c_int, ctx: lua_KContext) -> c_int; + +// Type for functions that read/write blocks when loading/dumping Lua chunks +#[rustfmt::skip] +pub type lua_Reader = + unsafe extern "C-unwind" fn(L: *mut lua_State, ud: *mut c_void, sz: *mut usize) -> *const c_char; +#[rustfmt::skip] +pub type lua_Writer = + unsafe extern "C-unwind" fn(L: *mut lua_State, p: *const c_void, sz: usize, ud: *mut c_void) -> c_int; + +/// Type for memory-allocation functions (no unwinding) +#[rustfmt::skip] +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; + +/// Type for warning functions +pub type lua_WarnFunction = unsafe extern "C-unwind" fn(ud: *mut c_void, msg: *const c_char, tocont: c_int); + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + // + // State manipulation + // + pub fn lua_newstate(f: lua_Alloc, ud: *mut c_void, seed: c_uint) -> *mut lua_State; + pub fn lua_close(L: *mut lua_State); + pub fn lua_newthread(L: *mut lua_State) -> *mut lua_State; + pub fn lua_closethread(L: *mut lua_State, from: *mut lua_State) -> c_int; + + pub fn lua_atpanic(L: *mut lua_State, panicf: lua_CFunction) -> lua_CFunction; + + pub fn lua_version(L: *mut lua_State) -> lua_Number; + + // + // Basic stack manipulation + // + pub fn lua_absindex(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_gettop(L: *mut lua_State) -> c_int; + pub fn lua_settop(L: *mut lua_State, idx: c_int); + pub fn lua_pushvalue(L: *mut lua_State, idx: c_int); + pub fn lua_rotate(L: *mut lua_State, idx: c_int, n: c_int); + pub fn lua_copy(L: *mut lua_State, fromidx: c_int, toidx: c_int); + pub fn lua_checkstack(L: *mut lua_State, sz: c_int) -> c_int; + + pub fn lua_xmove(from: *mut lua_State, to: *mut lua_State, n: c_int); + + // + // Access functions (stack -> C) + // + pub fn lua_isnumber(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_isstring(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_iscfunction(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_isinteger(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_isuserdata(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_type(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_typename(L: *mut lua_State, tp: c_int) -> *const c_char; + + pub fn lua_tonumberx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Number; + pub fn lua_tointegerx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Integer; + pub fn lua_toboolean(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; + #[link_name = "lua_rawlen"] + fn lua_rawlen_(L: *mut lua_State, idx: c_int) -> lua_Unsigned; + pub fn lua_tocfunction(L: *mut lua_State, idx: c_int) -> Option; + pub fn lua_touserdata(L: *mut lua_State, idx: c_int) -> *mut c_void; + pub fn lua_tothread(L: *mut lua_State, idx: c_int) -> *mut lua_State; + pub fn lua_topointer(L: *mut lua_State, idx: c_int) -> *const c_void; +} + +// lua_rawlen's return type changed from size_t to lua_Unsigned int in Lua 5.4. +// This adapts the crate API to the new Lua ABI. +#[inline(always)] +pub unsafe fn lua_rawlen(L: *mut lua_State, idx: c_int) -> usize { + lua_rawlen_(L, idx) as usize +} + +// +// Comparison and arithmetic functions +// +pub const LUA_OPADD: c_int = 0; +pub const LUA_OPSUB: c_int = 1; +pub const LUA_OPMUL: c_int = 2; +pub const LUA_OPMOD: c_int = 3; +pub const LUA_OPPOW: c_int = 4; +pub const LUA_OPDIV: c_int = 5; +pub const LUA_OPIDIV: c_int = 6; +pub const LUA_OPBAND: c_int = 7; +pub const LUA_OPBOR: c_int = 8; +pub const LUA_OPBXOR: c_int = 9; +pub const LUA_OPSHL: c_int = 10; +pub const LUA_OPSHR: c_int = 11; +pub const LUA_OPUNM: c_int = 12; +pub const LUA_OPBNOT: c_int = 13; + +pub const LUA_OPEQ: c_int = 0; +pub const LUA_OPLT: c_int = 1; +pub const LUA_OPLE: c_int = 2; + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_arith(L: *mut lua_State, op: c_int); + pub fn lua_rawequal(L: *mut lua_State, idx1: c_int, idx2: c_int) -> c_int; + pub fn lua_compare(L: *mut lua_State, idx1: c_int, idx2: c_int, op: c_int) -> c_int; +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + // + // Push functions (C -> stack) + // + pub fn lua_pushnil(L: *mut lua_State); + pub fn lua_pushnumber(L: *mut lua_State, n: lua_Number); + pub fn lua_pushinteger(L: *mut lua_State, n: lua_Integer); + pub fn lua_pushlstring(L: *mut lua_State, s: *const c_char, len: usize) -> *const c_char; + pub fn lua_pushexternalstring( + L: *mut lua_State, + s: *const c_char, + len: usize, + falloc: Option, + ud: *mut c_void, + ) -> *const c_char; + pub fn lua_pushstring(L: *mut lua_State, s: *const c_char) -> *const c_char; + // lua_pushvfstring + pub fn lua_pushfstring(L: *mut lua_State, fmt: *const c_char, ...) -> *const c_char; + pub fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, n: c_int); + pub fn lua_pushboolean(L: *mut lua_State, b: c_int); + pub fn lua_pushlightuserdata(L: *mut lua_State, p: *mut c_void); + pub fn lua_pushthread(L: *mut lua_State) -> c_int; + + // + // Get functions (Lua -> stack) + // + pub fn lua_getglobal(L: *mut lua_State, name: *const c_char) -> c_int; + pub fn lua_gettable(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_getfield(L: *mut lua_State, idx: c_int, k: *const c_char) -> c_int; + pub fn lua_geti(L: *mut lua_State, idx: c_int, n: lua_Integer) -> c_int; + pub fn lua_rawget(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_rawgeti(L: *mut lua_State, idx: c_int, n: lua_Integer) -> c_int; + pub fn lua_rawgetp(L: *mut lua_State, idx: c_int, p: *const c_void) -> c_int; + + pub fn lua_createtable(L: *mut lua_State, narr: c_int, nrec: c_int); + pub fn lua_newuserdatauv(L: *mut lua_State, sz: usize, nuvalue: c_int) -> *mut c_void; + pub fn lua_getmetatable(L: *mut lua_State, objindex: c_int) -> c_int; + pub fn lua_getiuservalue(L: *mut lua_State, idx: c_int, n: c_int) -> c_int; + + // + // Set functions (stack -> Lua) + // + pub fn lua_setglobal(L: *mut lua_State, name: *const c_char); + pub fn lua_settable(L: *mut lua_State, idx: c_int); + pub fn lua_setfield(L: *mut lua_State, idx: c_int, k: *const c_char); + pub fn lua_seti(L: *mut lua_State, idx: c_int, n: lua_Integer); + pub fn lua_rawset(L: *mut lua_State, idx: c_int); + pub fn lua_rawseti(L: *mut lua_State, idx: c_int, n: lua_Integer); + pub fn lua_rawsetp(L: *mut lua_State, idx: c_int, p: *const c_void); + pub fn lua_setmetatable(L: *mut lua_State, objindex: c_int) -> c_int; + pub fn lua_setiuservalue(L: *mut lua_State, idx: c_int, n: c_int) -> c_int; + + // + // 'load' and 'call' functions (load and run Lua code) + // + pub fn lua_callk( + L: *mut lua_State, + nargs: c_int, + nresults: c_int, + ctx: lua_KContext, + k: Option, + ); + pub fn lua_pcallk( + L: *mut lua_State, + nargs: c_int, + nresults: c_int, + errfunc: c_int, + ctx: lua_KContext, + k: Option, + ) -> c_int; + + pub fn lua_load( + L: *mut lua_State, + reader: lua_Reader, + data: *mut c_void, + chunkname: *const c_char, + mode: *const c_char, + ) -> c_int; + + pub fn lua_dump(L: *mut lua_State, writer: lua_Writer, data: *mut c_void, strip: c_int) -> c_int; +} + +#[inline(always)] +pub unsafe fn lua_call(L: *mut lua_State, n: c_int, r: c_int) { + lua_callk(L, n, r, 0, None) +} + +#[inline(always)] +pub unsafe fn lua_pcall(L: *mut lua_State, n: c_int, r: c_int, f: c_int) -> c_int { + lua_pcallk(L, n, r, f, 0, None) +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + // + // Coroutine functions + // + pub fn lua_yieldk( + L: *mut lua_State, + nresults: c_int, + ctx: lua_KContext, + k: Option, + ) -> c_int; + pub fn lua_resume(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int; + pub fn lua_status(L: *mut lua_State) -> c_int; + pub fn lua_isyieldable(L: *mut lua_State) -> c_int; +} + +#[inline(always)] +pub unsafe fn lua_yield(L: *mut lua_State, n: c_int) -> c_int { + lua_yieldk(L, n, 0, None) +} + +// +// Warning-related functions +// +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_setwarnf(L: *mut lua_State, f: Option, ud: *mut c_void); + pub fn lua_warning(L: *mut lua_State, msg: *const c_char, tocont: c_int); +} + +// +// Garbage-collection options +// +pub const LUA_GCSTOP: c_int = 0; +pub const LUA_GCRESTART: c_int = 1; +pub const LUA_GCCOLLECT: c_int = 2; +pub const LUA_GCCOUNT: c_int = 3; +pub const LUA_GCCOUNTB: c_int = 4; +pub const LUA_GCSTEP: c_int = 5; +pub const LUA_GCISRUNNING: c_int = 6; +pub const LUA_GCGEN: c_int = 7; +pub const LUA_GCINC: c_int = 8; +pub const LUA_GCPARAM: c_int = 9; + +// Parameters for GC generational mode +pub const LUA_GCPMINORMUL: c_int = 0; // control minor collections +pub const LUA_GCPMAJORMINOR: c_int = 1; // control shift major->minor +pub const LUA_GCPMINORMAJOR: c_int = 2; // control shift minor->major + +// Parameters for GC incremental mode +pub const LUA_GCPPAUSE: c_int = 3; // size of pause between successive GCs +pub const LUA_GCPSTEPMUL: c_int = 4; // GC "speed" +pub const LUA_GCPSTEPSIZE: c_int = 5; // GC granularity + +pub const LUA_GCPNUM: c_int = 6; // number of parameters + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_gc(L: *mut lua_State, what: c_int, ...) -> c_int; +} + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + // + // Miscellaneous functions + // + #[link_name = "lua_error"] + fn lua_error_(L: *mut lua_State) -> c_int; + pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_concat(L: *mut lua_State, n: c_int); + pub fn lua_len(L: *mut lua_State, idx: c_int); + pub fn lua_numbertocstring(L: *mut lua_State, idx: c_int, buff: *mut c_char) -> c_uint; + pub fn lua_stringtonumber(L: *mut lua_State, s: *const c_char) -> usize; + pub fn lua_getallocf(L: *mut lua_State, ud: *mut *mut c_void) -> lua_Alloc; + pub fn lua_setallocf(L: *mut lua_State, f: lua_Alloc, ud: *mut c_void); + + pub fn lua_toclose(L: *mut lua_State, idx: c_int); + pub fn lua_closeslot(L: *mut lua_State, idx: c_int); +} + +// lua_error does not return but is declared to return int, and Rust translates +// ! to void which can cause link-time errors if the platform linker is aware +// of return types and requires they match (for example: wasm does this). +#[inline(always)] +pub unsafe fn lua_error(L: *mut lua_State) -> ! { + lua_error_(L); + unreachable!(); +} + +// +// Some useful macros (implemented as Rust functions) +// +#[inline(always)] +pub unsafe fn lua_getextraspace(L: *mut lua_State) -> *mut c_void { + (L as *mut c_char).sub(LUA_EXTRASPACE) as *mut c_void +} + +#[inline(always)] +pub unsafe fn lua_tonumber(L: *mut lua_State, i: c_int) -> lua_Number { + lua_tonumberx(L, i, ptr::null_mut()) +} + +#[inline(always)] +pub unsafe fn lua_tointeger(L: *mut lua_State, i: c_int) -> lua_Integer { + lua_tointegerx(L, i, ptr::null_mut()) +} + +#[inline(always)] +pub unsafe fn lua_pop(L: *mut lua_State, n: c_int) { + lua_settop(L, -n - 1) +} + +#[inline(always)] +pub unsafe fn lua_newtable(L: *mut lua_State) { + lua_createtable(L, 0, 0) +} + +#[inline(always)] +pub unsafe fn lua_register(L: *mut lua_State, n: *const c_char, f: lua_CFunction) { + lua_pushcfunction(L, f); + lua_setglobal(L, n) +} + +#[inline(always)] +pub unsafe fn lua_pushcfunction(L: *mut lua_State, f: lua_CFunction) { + lua_pushcclosure(L, f, 0) +} + +#[inline(always)] +pub unsafe fn lua_isfunction(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TFUNCTION) as c_int +} + +#[inline(always)] +pub unsafe fn lua_istable(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TTABLE) as c_int +} + +#[inline(always)] +pub unsafe fn lua_islightuserdata(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TLIGHTUSERDATA) as c_int +} + +#[inline(always)] +pub unsafe fn lua_isnil(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TNIL) as c_int +} + +#[inline(always)] +pub unsafe fn lua_isboolean(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TBOOLEAN) as c_int +} + +#[inline(always)] +pub unsafe fn lua_isthread(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TTHREAD) as c_int +} + +#[inline(always)] +pub unsafe fn lua_isnone(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TNONE) as c_int +} + +#[inline(always)] +pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) <= 0) as c_int +} + +#[inline(always)] +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring(L, s.as_ptr()); +} + +#[inline(always)] +pub unsafe fn lua_pushglobaltable(L: *mut lua_State) -> c_int { + lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS) +} + +#[inline(always)] +pub unsafe fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void { + if lua_islightuserdata(L, idx) != 0 { + return lua_touserdata(L, idx); + } + ptr::null_mut() +} + +#[inline(always)] +pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { + lua_tolstring(L, i, ptr::null_mut()) +} + +#[inline(always)] +pub unsafe fn lua_insert(L: *mut lua_State, idx: c_int) { + lua_rotate(L, idx, 1) +} + +#[inline(always)] +pub unsafe fn lua_remove(L: *mut lua_State, idx: c_int) { + lua_rotate(L, idx, -1); + lua_pop(L, 1) +} + +#[inline(always)] +pub unsafe fn lua_replace(L: *mut lua_State, idx: c_int) { + lua_copy(L, -1, idx); + lua_pop(L, 1) +} + +#[inline(always)] +pub unsafe fn lua_xpush(from: *mut lua_State, to: *mut lua_State, idx: c_int) { + lua_pushvalue(from, idx); + lua_xmove(from, to, 1); +} + +#[inline(always)] +pub unsafe fn lua_newuserdata(L: *mut lua_State, sz: usize) -> *mut c_void { + lua_newuserdatauv(L, sz, 1) +} + +#[inline(always)] +pub unsafe fn lua_getuservalue(L: *mut lua_State, idx: c_int) -> c_int { + lua_getiuservalue(L, idx, 1) +} + +#[inline(always)] +pub unsafe fn lua_setuservalue(L: *mut lua_State, idx: c_int) -> c_int { + lua_setiuservalue(L, idx, 1) +} + +// +// Debug API +// + +// Maximum size for the description of the source of a function in debug information. +const LUA_IDSIZE: usize = 60; + +// Event codes +pub const LUA_HOOKCALL: c_int = 0; +pub const LUA_HOOKRET: c_int = 1; +pub const LUA_HOOKLINE: c_int = 2; +pub const LUA_HOOKCOUNT: c_int = 3; +pub const LUA_HOOKTAILCALL: c_int = 4; + +// Event masks +pub const LUA_MASKCALL: c_int = 1 << (LUA_HOOKCALL as usize); +pub const LUA_MASKRET: c_int = 1 << (LUA_HOOKRET as usize); +pub const LUA_MASKLINE: c_int = 1 << (LUA_HOOKLINE as usize); +pub const LUA_MASKCOUNT: c_int = 1 << (LUA_HOOKCOUNT as usize); + +/// Type for functions to be called on debug events. +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn lua_getstack(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) -> c_int; + pub fn lua_getinfo(L: *mut lua_State, what: *const c_char, ar: *mut lua_Debug) -> c_int; + pub fn lua_getlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; + pub fn lua_setlocal(L: *mut lua_State, ar: *const lua_Debug, n: c_int) -> *const c_char; + pub fn lua_getupvalue(L: *mut lua_State, funcindex: c_int, n: c_int) -> *const c_char; + pub fn lua_setupvalue(L: *mut lua_State, funcindex: c_int, n: c_int) -> *const c_char; + + pub fn lua_upvalueid(L: *mut lua_State, fidx: c_int, n: c_int) -> *mut c_void; + pub fn lua_upvaluejoin(L: *mut lua_State, fidx1: c_int, n1: c_int, fidx2: c_int, n2: c_int); + + pub fn lua_sethook(L: *mut lua_State, func: Option, mask: c_int, count: c_int); + pub fn lua_gethook(L: *mut lua_State) -> Option; + pub fn lua_gethookmask(L: *mut lua_State) -> c_int; + pub fn lua_gethookcount(L: *mut lua_State) -> c_int; +} + +#[repr(C)] +pub struct lua_Debug { + pub event: c_int, + pub name: *const c_char, // (n) + pub namewhat: *const c_char, // (n) 'global', 'local', 'field', 'method' + pub what: *const c_char, // (S) 'Lua', 'C', 'main', 'tail' + pub source: *const c_char, // (S) + pub srclen: usize, // (S) + pub currentline: c_int, // (l) + pub linedefined: c_int, // (S) + pub lastlinedefined: c_int, // (S) + pub nups: c_uchar, // (u) number of upvalues + pub nparams: c_uchar, // (u) number of parameters + pub isvararg: c_char, // (u) + pub extraargs: c_uchar, // (t) number of extra arguments + pub istailcall: c_char, // (t) + pub ftransfer: c_int, // (r) index of first value transferred + pub ntransfer: c_int, // (r) number of transferred values + pub short_src: [c_char; LUA_IDSIZE], // (S) + // lua.h mentions this is for private use + i_ci: *mut c_void, +} diff --git a/mlua-sys/src/lua55/lualib.rs b/mlua-sys/src/lua55/lualib.rs new file mode 100644 index 00000000..33903add --- /dev/null +++ b/mlua-sys/src/lua55/lualib.rs @@ -0,0 +1,55 @@ +//! Contains definitions from `lualib.h`. + +use std::os::raw::{c_char, c_int}; + +use super::lua::lua_State; + +pub const LUA_GLIBK: c_int = 1; + +pub const LUA_LOADLIBNAME: *const c_char = cstr!("package"); +pub const LUA_LOADLIBK: c_int = LUA_GLIBK << 1; + +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_COLIBK: c_int = LUA_GLIBK << 2; + +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_DBLIBK: c_int = LUA_GLIBK << 3; + +pub const LUA_IOLIBNAME: *const c_char = cstr!("io"); +pub const LUA_IOLIBK: c_int = LUA_GLIBK << 4; + +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_MATHLIBK: c_int = LUA_GLIBK << 5; + +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_OSLIBK: c_int = LUA_GLIBK << 6; + +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_STRLIBK: c_int = LUA_GLIBK << 7; + +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_TABLIBK: c_int = LUA_GLIBK << 8; + +pub const LUA_UTF8LIBNAME: *const c_char = cstr!("utf8"); +pub const LUA_UTF8LIBK: c_int = LUA_GLIBK << 9; + +#[cfg_attr(all(windows, raw_dylib), link(name = "lua55", kind = "raw-dylib"))] +unsafe extern "C-unwind" { + pub fn luaopen_base(L: *mut lua_State) -> c_int; + pub fn luaopen_package(L: *mut lua_State) -> c_int; + pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; + pub fn luaopen_debug(L: *mut lua_State) -> c_int; + pub fn luaopen_io(L: *mut lua_State) -> c_int; + pub fn luaopen_math(L: *mut lua_State) -> c_int; + pub fn luaopen_os(L: *mut lua_State) -> c_int; + pub fn luaopen_string(L: *mut lua_State) -> c_int; + pub fn luaopen_table(L: *mut lua_State) -> c_int; + pub fn luaopen_utf8(L: *mut lua_State) -> c_int; + + // open all builtin libraries + pub fn luaL_openselectedlibs(L: *mut lua_State, load: c_int, preload: c_int); +} + +pub unsafe fn luaL_openlibs(L: *mut lua_State) { + luaL_openselectedlibs(L, !0, 0); +} diff --git a/mlua-sys/src/lua55/mod.rs b/mlua-sys/src/lua55/mod.rs new file mode 100644 index 00000000..8b653a4c --- /dev/null +++ b/mlua-sys/src/lua55/mod.rs @@ -0,0 +1,9 @@ +//! Low level bindings to Lua 5.5. + +pub use lauxlib::*; +pub use lua::*; +pub use lualib::*; + +pub mod lauxlib; +pub mod lua; +pub mod lualib; diff --git a/src/ffi/luau/compat.rs b/mlua-sys/src/luau/compat.rs similarity index 61% rename from src/ffi/luau/compat.rs rename to mlua-sys/src/luau/compat.rs index eae46126..bf2a7b95 100644 --- a/src/ffi/luau/compat.rs +++ b/mlua-sys/src/luau/compat.rs @@ -1,16 +1,17 @@ -//! MLua compatibility layer for Roblox Luau. +//! MLua compatibility layer for Luau. //! //! Based on github.com/keplerproject/lua-compat-5.3 use std::ffi::CStr; -use std::mem; use std::os::raw::{c_char, c_int, c_void}; -use std::ptr; +use std::{mem, ptr}; use super::lauxlib::*; use super::lua::*; use super::luacode::*; +pub const LUA_RESUMEERROR: c_int = -1; + unsafe fn compat53_reverse(L: *mut lua_State, mut a: c_int, mut b: c_int) { while a < b { lua_pushvalue(L, a); @@ -22,8 +23,8 @@ unsafe fn compat53_reverse(L: *mut lua_State, mut a: c_int, mut b: c_int) { } } -const COMPAT53_LEVELS1: c_int = 12; // size of the first part of the stack -const COMPAT53_LEVELS2: c_int = 10; // size of the second part of the stack +const COMPAT53_LEVELS1: c_int = 10; // size of the first part of the stack +const COMPAT53_LEVELS2: c_int = 11; // size of the second part of the stack unsafe fn compat53_findfield(L: *mut lua_State, objidx: c_int, level: c_int) -> c_int { if level == 0 || lua_istable(L, -1) == 0 { @@ -40,11 +41,10 @@ unsafe fn compat53_findfield(L: *mut lua_State, objidx: c_int, level: c_int) -> lua_pop(L, 1); // remove value (but keep name) return 1; } else if compat53_findfield(L, objidx, level - 1) != 0 { - // try recursively - lua_remove(L, -2); // remove table (but keep name) - lua_pushliteral(L, "."); - lua_insert(L, -2); // place '.' between the two names - lua_concat(L, 3); + // stack: lib_name, lib_table, field_name (top) + lua_pushliteral(L, c"."); // place '.' between the two names + lua_replace(L, -3); // (in the slot occupied by table) + lua_concat(L, 3); // lib_name.field_name return 1; } } @@ -55,16 +55,23 @@ unsafe fn compat53_findfield(L: *mut lua_State, objidx: c_int, level: c_int) -> unsafe fn compat53_pushglobalfuncname( L: *mut lua_State, + L1: *mut lua_State, level: c_int, ar: *mut lua_Debug, ) -> c_int { let top = lua_gettop(L); - // push function - lua_getinfo(L, level, cstr!("f"), ar); + lua_getinfo(L1, level, cstr!("f"), ar); // push function + lua_xmove(L1, L, 1); // and move onto L lua_pushvalue(L, LUA_GLOBALSINDEX); + luaL_checkstack(L, 6, cstr!("not enough stack")); // slots for 'findfield' if compat53_findfield(L, top + 1, 2) != 0 { + let name = lua_tostring(L, -1); + if CStr::from_ptr(name).to_bytes().starts_with(b"_G.") { + lua_pushstring(L, name.add(3)); // push name without prefix + lua_remove(L, -2); // remove original name + } lua_copy(L, -1, top + 1); // move name to proper place - lua_pop(L, 2); // remove pushed values + lua_settop(L, top + 1); // remove pushed values 1 } else { lua_settop(L, top); // remove function and global table @@ -72,15 +79,18 @@ unsafe fn compat53_pushglobalfuncname( } } -unsafe fn compat53_pushfuncname(L: *mut lua_State, level: c_int, ar: *mut lua_Debug) { +unsafe fn compat53_pushfuncname(L: *mut lua_State, L1: *mut lua_State, level: c_int, ar: *mut lua_Debug) { if !(*ar).name.is_null() { // is there a name? lua_pushfstring(L, cstr!("function '%s'"), (*ar).name); - } else if compat53_pushglobalfuncname(L, level, ar) != 0 { + } else if compat53_pushglobalfuncname(L, L1, level, ar) != 0 { lua_pushfstring(L, cstr!("function '%s'"), lua_tostring(L, -1)); lua_remove(L, -2); // remove name + } else if *(*ar).what != b'C' as c_char { + // for Lua functions, use + lua_pushfstring(L, cstr!("function <%s:%d>"), (*ar).short_src, (*ar).linedefined); } else { - lua_pushliteral(L, "?"); + lua_pushliteral(L, c"?"); } } @@ -113,7 +123,7 @@ pub unsafe fn lua_rotate(L: *mut lua_State, mut idx: c_int, mut n: c_int) { #[inline(always)] pub unsafe fn lua_copy(L: *mut lua_State, fromidx: c_int, toidx: c_int) { let abs_to = lua_absindex(L, toidx); - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); lua_pushvalue(L, fromidx); lua_replace(L, abs_to); } @@ -123,13 +133,19 @@ pub unsafe fn lua_isinteger(L: *mut lua_State, idx: c_int) -> c_int { if lua_type(L, idx) == LUA_TNUMBER { let n = lua_tonumber(L, idx); let i = lua_tointeger(L, idx); - if (n - i as lua_Number).abs() < lua_Number::EPSILON { + // Lua 5.3+ returns "false" for `-0.0` + if n.to_bits() == (i as lua_Number).to_bits() { return 1; } } 0 } +#[inline(always)] +pub unsafe fn lua_pushinteger(L: *mut lua_State, i: lua_Integer) { + lua_pushnumber(L, i as lua_Number); +} + #[inline(always)] pub unsafe fn lua_tointeger(L: *mut lua_State, i: c_int) -> lua_Integer { lua_tointegerx(L, i, ptr::null_mut()) @@ -181,21 +197,20 @@ pub unsafe fn lua_geti(L: *mut lua_State, mut idx: c_int, n: lua_Integer) -> c_i #[inline(always)] pub unsafe fn lua_rawgeti(L: *mut lua_State, idx: c_int, n: lua_Integer) -> c_int { + let n = n.try_into().expect("cannot convert index from lua_Integer"); lua_rawgeti_(L, idx, n) } #[inline(always)] pub unsafe fn lua_rawgetp(L: *mut lua_State, idx: c_int, p: *const c_void) -> c_int { - let abs_i = lua_absindex(L, idx); - lua_pushlightuserdata(L, p as *mut c_void); - lua_rawget(L, abs_i) + lua_rawgetptagged(L, idx, p, 0) } #[inline(always)] pub unsafe fn lua_getuservalue(L: *mut lua_State, mut idx: c_int) -> c_int { luaL_checkstack(L, 2, cstr!("not enough stack slots available")); idx = lua_absindex(L, idx); - lua_pushliteral(L, "__mlua_uservalues"); + lua_pushliteral(L, c"__mlua_uservalues"); if lua_rawget(L, LUA_REGISTRYINDEX) != LUA_TTABLE { return LUA_TNIL; } @@ -216,29 +231,26 @@ pub unsafe fn lua_seti(L: *mut lua_State, mut idx: c_int, n: lua_Integer) { #[inline(always)] pub unsafe fn lua_rawseti(L: *mut lua_State, idx: c_int, n: lua_Integer) { + let n = n.try_into().expect("cannot convert index from lua_Integer"); lua_rawseti_(L, idx, n) } #[inline(always)] pub unsafe fn lua_rawsetp(L: *mut lua_State, idx: c_int, p: *const c_void) { - let abs_i = lua_absindex(L, idx); - luaL_checkstack(L, 1, cstr!("not enough stack slots")); - lua_pushlightuserdata(L, p as *mut c_void); - lua_insert(L, -2); - lua_rawset(L, abs_i); + lua_rawsetptagged(L, idx, p, 0) } #[inline(always)] pub unsafe fn lua_setuservalue(L: *mut lua_State, mut idx: c_int) { luaL_checkstack(L, 4, cstr!("not enough stack slots available")); idx = lua_absindex(L, idx); - lua_pushliteral(L, "__mlua_uservalues"); + lua_pushliteral(L, c"__mlua_uservalues"); lua_pushvalue(L, -1); if lua_rawget(L, LUA_REGISTRYINDEX) != LUA_TTABLE { lua_pop(L, 1); lua_createtable(L, 0, 2); // main table lua_createtable(L, 0, 1); // metatable - lua_pushliteral(L, "k"); + lua_pushliteral(L, c"k"); lua_setfield(L, -2, cstr!("__mode")); lua_setmetatable(L, -2); lua_pushvalue(L, -2); @@ -281,12 +293,7 @@ pub unsafe fn lua_pushglobaltable(L: *mut lua_State) { } #[inline(always)] -pub unsafe fn lua_resume( - L: *mut lua_State, - from: *mut lua_State, - narg: c_int, - nres: *mut c_int, -) -> c_int { +pub unsafe fn lua_resume(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int { let ret = lua_resume_(L, from, narg); if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { *nres = lua_gettop(L); @@ -294,6 +301,19 @@ pub unsafe fn lua_resume( ret } +#[inline(always)] +pub unsafe fn lua_resumex(L: *mut lua_State, from: *mut lua_State, narg: c_int, nres: *mut c_int) -> c_int { + let ret = if narg == LUA_RESUMEERROR { + lua_resumeerror(L, from) + } else { + lua_resume_(L, from, narg) + }; + if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { + *nres = lua_gettop(L); + } + ret +} + // // lauxlib ported functions // @@ -304,12 +324,30 @@ pub unsafe fn luaL_checkstack(L: *mut lua_State, sz: c_int, msg: *const c_char) if !msg.is_null() { luaL_error(L, cstr!("stack overflow (%s)"), msg); } else { - lua_pushliteral(L, "stack overflow"); + lua_pushliteral(L, c"stack overflow"); lua_error(L); } } } +#[inline(always)] +pub unsafe fn luaL_checkinteger(L: *mut lua_State, narg: c_int) -> lua_Integer { + let mut isnum = 0; + let int = lua_tointegerx(L, narg, &mut isnum); + if isnum == 0 { + luaL_typeerror(L, narg, lua_typename(L, LUA_TNUMBER)); + } + int +} + +pub unsafe fn luaL_optinteger(L: *mut lua_State, narg: c_int, def: lua_Integer) -> lua_Integer { + if lua_isnoneornil(L, narg) != 0 { + def + } else { + luaL_checkinteger(L, narg) + } +} + #[inline(always)] pub unsafe fn luaL_getmetafield(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int { if luaL_getmetafield_(L, obj, e) != 0 { @@ -323,57 +361,79 @@ pub unsafe fn luaL_getmetafield(L: *mut lua_State, obj: c_int, e: *const c_char) pub unsafe fn luaL_newmetatable(L: *mut lua_State, tname: *const c_char) -> c_int { if luaL_newmetatable_(L, tname) != 0 { lua_pushstring(L, tname); - lua_setfield(L, -2, cstr!("__name")); + lua_setfield(L, -2, cstr!("__type")); 1 } else { 0 } } -pub unsafe fn luaL_loadbufferx( +pub unsafe fn luaL_loadbufferenv( L: *mut lua_State, data: *const c_char, mut size: usize, name: *const c_char, mode: *const c_char, + mut env: c_int, ) -> c_int { - extern "C" { + unsafe extern "C" { fn free(p: *mut c_void); } - let chunk_is_text = (*data as u8) >= b'\n'; + unsafe extern "C" fn data_dtor(_: *mut lua_State, data: *mut c_void) { + free(*(data as *mut *mut c_char) as *mut c_void); + } + + let chunk_is_text = size == 0 || (*data as u8) >= b'\t'; if !mode.is_null() { let modeb = CStr::from_ptr(mode).to_bytes(); if !chunk_is_text && !modeb.contains(&b'b') { - lua_pushfstring( - L, - cstr!("attempt to load a binary chunk (mode is '%s')"), - mode, - ); + lua_pushfstring(L, cstr!("attempt to load a binary chunk (mode is '%s')"), mode); return LUA_ERRSYNTAX; } else if chunk_is_text && !modeb.contains(&b't') { - lua_pushfstring( - L, - cstr!("attempt to load a text chunk (mode is '%s')"), - mode, - ); + lua_pushfstring(L, cstr!("attempt to load a text chunk (mode is '%s')"), mode); return LUA_ERRSYNTAX; } } - if chunk_is_text { + let status = if chunk_is_text { + if env < 0 { + env -= 1; + } + let data_ud = lua_newuserdatadtor(L, mem::size_of::<*mut c_char>(), data_dtor) as *mut *mut c_char; let data = luau_compile_(data, size, ptr::null_mut(), &mut size); - let ok = luau_load(L, name, data, size, 0) == 0; - free(data as *mut c_void); - if !ok { - return LUA_ERRSYNTAX; + ptr::write(data_ud, data); + // By deferring the `free(data)` to the userdata destructor, we ensure that + // even if `luau_load` throws an error, the `data` is still released. + let status = luau_load(L, name, data, size, env); + lua_replace(L, -2); // replace data with the result + status + } else { + luau_load(L, name, data, size, env) + }; + + if status != 0 { + if lua_isstring(L, -1) != 0 && CStr::from_ptr(lua_tostring(L, -1)) == c"not enough memory" { + // A case for Luau >= 0.679 + return LUA_ERRMEM; } - } else if luau_load(L, name, data, size, 0) != 0 { return LUA_ERRSYNTAX; } + LUA_OK } +#[inline(always)] +pub unsafe fn luaL_loadbufferx( + L: *mut lua_State, + data: *const c_char, + size: usize, + name: *const c_char, + mode: *const c_char, +) -> c_int { + luaL_loadbufferenv(L, data, size, name, mode, 0) +} + #[inline(always)] pub unsafe fn luaL_loadbuffer( L: *mut lua_State, @@ -381,13 +441,13 @@ pub unsafe fn luaL_loadbuffer( size: usize, name: *const c_char, ) -> c_int { - luaL_loadbufferx(L, data, size, name, ptr::null()) + luaL_loadbufferenv(L, data, size, name, ptr::null(), 0) } #[inline(always)] pub unsafe fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer { let mut isnum = 0; - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); lua_len(L, idx); let res = lua_tointegerx(L, -1, &mut isnum); lua_pop(L, 1); @@ -397,64 +457,65 @@ pub unsafe fn luaL_len(L: *mut lua_State, idx: c_int) -> lua_Integer { res } -pub unsafe fn luaL_traceback( - L: *mut lua_State, - L1: *mut lua_State, - msg: *const c_char, - mut level: c_int, -) { +pub unsafe fn luaL_traceback(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, mut level: c_int) { let mut ar: lua_Debug = mem::zeroed(); - let top = lua_gettop(L); let numlevels = lua_stackdepth(L); - let mark = if numlevels > COMPAT53_LEVELS1 + COMPAT53_LEVELS2 { - COMPAT53_LEVELS1 - } else { - 0 - }; + #[rustfmt::skip] + let mut limit = if numlevels - level > COMPAT53_LEVELS1 + COMPAT53_LEVELS2 { COMPAT53_LEVELS1 } else { -1 }; + + let mut buf: luaL_Strbuf = mem::zeroed(); + luaL_buffinit(L, &mut buf); if !msg.is_null() { - lua_pushfstring(L, cstr!("%s\n"), msg); + luaL_addstring(&mut buf, msg); + luaL_addstring(&mut buf, cstr!("\n")); } - lua_pushliteral(L, "stack traceback:"); - while lua_getinfo(L1, level, cstr!(""), &mut ar) != 0 { - if level + 1 == mark { + luaL_addstring(&mut buf, cstr!("stack traceback:")); + while lua_getinfo(L1, level, cstr!("sln"), &mut ar) != 0 { + if limit == 0 { // too many levels? - lua_pushliteral(L, "\n\t..."); // add a '...' - level = numlevels - COMPAT53_LEVELS2; // and skip to last ones + let n = numlevels - level - COMPAT53_LEVELS2; + // add warning about skip ("n + 1" because we skip current level too) + lua_pushfstring(L, cstr!("\n\t...\t(skipping %d levels)"), n + 1); + luaL_addvalue(&mut buf); + level += n; // and skip to last levels } else { - lua_getinfo(L1, level, cstr!("sln"), &mut ar); - lua_pushfstring(L, cstr!("\n\t%s:"), ar.short_src.as_ptr()); + luaL_addstring(&mut buf, cstr!("\n\t")); + luaL_addstring(&mut buf, ar.short_src); + luaL_addstring(&mut buf, cstr!(":")); if ar.currentline > 0 { - lua_pushfstring(L, cstr!("%d:"), ar.currentline); + luaL_addunsigned(&mut buf, ar.currentline as _); + luaL_addstring(&mut buf, cstr!(":")); } - lua_pushliteral(L, " in "); - compat53_pushfuncname(L, level, &mut ar); - lua_concat(L, lua_gettop(L) - top); + luaL_addstring(&mut buf, cstr!(" in ")); + compat53_pushfuncname(L, L1, level, &mut ar); + luaL_addvalue(&mut buf); } level += 1; + limit -= 1; } - lua_concat(L, lua_gettop(L) - top); + luaL_pushresult(&mut buf); } -pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char { +pub unsafe fn luaL_tolstring(L: *mut lua_State, mut idx: c_int, len: *mut usize) -> *const c_char { + idx = lua_absindex(L, idx); if luaL_callmeta(L, idx, cstr!("__tostring")) == 0 { - let t = lua_type(L, idx); - match t { + match lua_type(L, idx) { LUA_TNIL => { - lua_pushliteral(L, "nil"); + lua_pushliteral(L, c"nil"); } LUA_TSTRING | LUA_TNUMBER => { lua_pushvalue(L, idx); } LUA_TBOOLEAN => { if lua_toboolean(L, idx) == 0 { - lua_pushliteral(L, "false"); + lua_pushliteral(L, c"false"); } else { - lua_pushliteral(L, "true"); + lua_pushliteral(L, c"true"); } } - _ => { - let tt = luaL_getmetafield(L, idx, cstr!("__name")); + t => { + let tt = luaL_getmetafield(L, idx, cstr!("__type")); let name = if tt == LUA_TSTRING { lua_tostring(L, -1) } else { @@ -462,7 +523,7 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> }; lua_pushfstring(L, cstr!("%s: %p"), name, lua_topointer(L, idx)); if tt != LUA_TNIL { - lua_replace(L, -2); + lua_replace(L, -2); // remove '__type' } } }; @@ -474,14 +535,14 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> #[inline(always)] pub unsafe fn luaL_setmetatable(L: *mut lua_State, tname: *const c_char) { - luaL_checkstack(L, 1, cstr!("not enough stack slots")); + luaL_checkstack(L, 1, cstr!("not enough stack slots available")); luaL_getmetatable(L, tname); lua_setmetatable(L, -2); } pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_char) -> c_int { let abs_i = lua_absindex(L, idx); - luaL_checkstack(L, 3, cstr!("not enough stack slots")); + luaL_checkstack(L, 3, cstr!("not enough stack slots available")); lua_pushstring_(L, fname); if lua_gettable(L, abs_i) == LUA_TTABLE { return 1; @@ -494,14 +555,9 @@ pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_ch 0 } -pub unsafe fn luaL_requiref( - L: *mut lua_State, - modname: *const c_char, - openf: lua_CFunction, - glb: c_int, -) { +pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) { luaL_checkstack(L, 3, cstr!("not enough stack slots available")); - luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED")); + luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE); if lua_getfield(L, -1, modname) == LUA_TNIL { lua_pop(L, 1); lua_pushcfunction(L, openf); @@ -513,6 +569,9 @@ pub unsafe fn luaL_requiref( if glb != 0 { lua_pushvalue(L, -1); lua_setglobal(L, modname); + } else { + lua_pushnil(L); + lua_setglobal(L, modname); } lua_replace(L, -2); } diff --git a/src/ffi/luau/lauxlib.rs b/mlua-sys/src/luau/lauxlib.rs similarity index 58% rename from src/ffi/luau/lauxlib.rs rename to mlua-sys/src/luau/lauxlib.rs index af7d1c2a..0328aaf0 100644 --- a/src/ffi/luau/lauxlib.rs +++ b/mlua-sys/src/luau/lauxlib.rs @@ -3,9 +3,10 @@ use std::os::raw::{c_char, c_float, c_int, c_void}; use std::ptr; -use super::lua::{ - self, lua_CFunction, lua_Integer, lua_Number, lua_State, lua_Unsigned, LUA_REGISTRYINDEX, -}; +use super::lua::{self, LUA_REGISTRYINDEX, lua_CFunction, lua_Number, lua_State, lua_Unsigned}; + +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); #[repr(C)] pub struct luaL_Reg { @@ -13,7 +14,7 @@ pub struct luaL_Reg { pub func: lua_CFunction, } -extern "C" { +unsafe extern "C-unwind" { pub fn luaL_register(L: *mut lua_State, libname: *const c_char, l: *const luaL_Reg); #[link_name = "luaL_getmetafield"] pub fn luaL_getmetafield_(L: *mut lua_State, obj: c_int, e: *const c_char) -> c_int; @@ -35,8 +36,12 @@ extern "C" { pub fn luaL_checkboolean(L: *mut lua_State, narg: c_int) -> c_int; pub fn luaL_optboolean(L: *mut lua_State, narg: c_int, def: c_int) -> c_int; - pub fn luaL_checkinteger(L: *mut lua_State, narg: c_int) -> lua_Integer; - pub fn luaL_optinteger(L: *mut lua_State, narg: c_int, def: lua_Integer) -> lua_Integer; + #[link_name = "luaL_checkinteger"] + pub fn luaL_checkinteger_(L: *mut lua_State, narg: c_int) -> c_int; + pub fn luaL_checkinteger64(L: *mut lua_State, narg: c_int) -> i64; + #[link_name = "luaL_optinteger"] + pub fn luaL_optinteger_(L: *mut lua_State, narg: c_int, def: c_int) -> c_int; + pub fn luaL_optinteger64(L: *mut lua_State, narg: c_int, def: i64) -> i64; pub fn luaL_checkunsigned(L: *mut lua_State, narg: c_int) -> lua_Unsigned; pub fn luaL_optunsigned(L: *mut lua_State, narg: c_int, def: lua_Unsigned) -> lua_Unsigned; @@ -52,6 +57,8 @@ extern "C" { pub fn luaL_newmetatable_(L: *mut lua_State, tname: *const c_char) -> c_int; pub fn luaL_checkudata(L: *mut lua_State, ud: c_int, tname: *const c_char) -> *mut c_void; + pub fn luaL_checkbuffer(L: *mut lua_State, narg: c_int, len: *mut usize) -> *mut c_void; + pub fn luaL_where(L: *mut lua_State, lvl: c_int); #[link_name = "luaL_errorL"] @@ -69,10 +76,20 @@ extern "C" { pub fn luaL_newstate() -> *mut lua_State; - // TODO: luaL_findtable + pub fn luaL_findtable( + L: *mut lua_State, + idx: c_int, + fname: *const c_char, + szhint: c_int, + ) -> *const c_char; pub fn luaL_typename(L: *mut lua_State, idx: c_int) -> *const c_char; + pub fn luaL_callyieldable(L: *mut lua_State, nargs: c_int, nresults: c_int) -> c_int; + + #[link_name = "luaL_traceback"] + pub fn luaL_traceback_(L: *mut lua_State, L1: *mut lua_State, msg: *const c_char, level: c_int); + // sandbox libraries and globals #[link_name = "luaL_sandbox"] pub fn luaL_sandbox_(L: *mut lua_State); @@ -107,7 +124,19 @@ pub unsafe fn luaL_optstring(L: *mut lua_State, n: c_int, d: *const c_char) -> * luaL_optlstring(L, n, d, ptr::null_mut()) } -// TODO: luaL_opt +#[inline(always)] +pub unsafe fn luaL_opt( + L: *mut lua_State, + f: unsafe extern "C-unwind" fn(*mut lua_State, c_int) -> T, + n: c_int, + d: T, +) -> T { + if lua::lua_isnoneornil(L, n) != 0 { + d + } else { + f(L, n) + } +} #[inline(always)] pub unsafe fn luaL_getmetatable(L: *mut lua_State, n: *const c_char) -> c_int { @@ -141,10 +170,13 @@ pub unsafe fn luaL_sandbox(L: *mut lua_State, enabled: c_int) { } // set all builtin metatables to read-only - lua_pushliteral(L, ""); - lua_getmetatable(L, -1); - lua_setreadonly(L, -1, enabled); - lua_pop(L, 2); + lua_pushliteral(L, c""); + if lua_getmetatable(L, -1) != 0 { + lua_setreadonly(L, -1, enabled); + lua_pop(L, 2); + } else { + lua_pop(L, 1); + } // set globals to readonly and activate safeenv since the env is immutable lua_setreadonly(L, LUA_GLOBALSINDEX, enabled); @@ -152,5 +184,63 @@ pub unsafe fn luaL_sandbox(L: *mut lua_State, enabled: c_int) { } // -// TODO: Generic Buffer Manipulation +// Generic Buffer Manipulation // + +/// Buffer size used for on-stack string operations. This limit depends on native stack size. +pub const LUA_BUFFERSIZE: usize = 512; + +#[repr(C)] +pub struct luaL_Strbuf { + p: *mut c_char, // current position in buffer + end: *mut c_char, // end of the current buffer + L: *mut lua_State, + storage: *mut c_void, // TString + buffer: [c_char; LUA_BUFFERSIZE], +} + +// For compatibility +pub type luaL_Buffer = luaL_Strbuf; + +unsafe extern "C-unwind" { + pub fn luaL_buffinit(L: *mut lua_State, B: *mut luaL_Strbuf); + pub fn luaL_buffinitsize(L: *mut lua_State, B: *mut luaL_Strbuf, size: usize) -> *mut c_char; + pub fn luaL_prepbuffsize(B: *mut luaL_Strbuf, size: usize) -> *mut c_char; + pub fn luaL_addlstring(B: *mut luaL_Strbuf, s: *const c_char, l: usize); + pub fn luaL_addvalue(B: *mut luaL_Strbuf); + pub fn luaL_addvalueany(B: *mut luaL_Strbuf, idx: c_int); + pub fn luaL_pushresult(B: *mut luaL_Strbuf); + pub fn luaL_pushresultsize(B: *mut luaL_Strbuf, size: usize); +} + +pub unsafe fn luaL_addchar(B: *mut luaL_Strbuf, c: c_char) { + if (*B).p >= (*B).end { + luaL_prepbuffsize(B, 1); + } + *(*B).p = c; + (*B).p = (*B).p.add(1); +} + +pub unsafe fn luaL_addstring(B: *mut luaL_Strbuf, s: *const c_char) { + // Calculate length of s + let mut len = 0; + while *s.add(len) != 0 { + len += 1; + } + luaL_addlstring(B, s, len); +} + +pub unsafe fn luaL_addunsigned(B: *mut luaL_Strbuf, mut n: lua_Unsigned) { + let mut buf: [c_char; 32] = [0; 32]; + let mut i = 32; + loop { + i -= 1; + let digit = (n % 10) as u8; + buf[i] = (b'0' + digit) as c_char; + n /= 10; + if n == 0 { + break; + } + } + luaL_addlstring(B, buf.as_ptr().add(i), 32 - i); +} diff --git a/src/ffi/luau/lua.rs b/mlua-sys/src/luau/lua.rs similarity index 62% rename from src/ffi/luau/lua.rs rename to mlua-sys/src/luau/lua.rs index e6c164f1..7fc6b19b 100644 --- a/src/ffi/luau/lua.rs +++ b/mlua-sys/src/luau/lua.rs @@ -1,18 +1,28 @@ //! Contains definitions from `lua.h`. +use std::ffi::CStr; use std::marker::{PhantomData, PhantomPinned}; use std::os::raw::{c_char, c_double, c_float, c_int, c_uint, c_void}; -use std::ptr; +use std::{mem, ptr}; // Option for multiple returns in 'lua_pcall' and 'lua_call' pub const LUA_MULTRET: c_int = -1; +// Max number of Lua stack slots +const LUAI_MAXCSTACK: c_int = 1000000; + +// Number of valid Lua userdata tags +pub const LUA_UTAG_LIMIT: c_int = 128; + +// Number of valid Lua lightuserdata tags +pub const LUA_LUTAG_LIMIT: c_int = 128; + // // Pseudo-indices // -pub const LUA_REGISTRYINDEX: c_int = -10000; -pub const LUA_ENVIRONINDEX: c_int = -10001; -pub const LUA_GLOBALSINDEX: c_int = -10002; +pub const LUA_REGISTRYINDEX: c_int = -LUAI_MAXCSTACK - 2000; +pub const LUA_ENVIRONINDEX: c_int = -LUAI_MAXCSTACK - 2001; +pub const LUA_GLOBALSINDEX: c_int = -LUAI_MAXCSTACK - 2002; pub const fn lua_upvalueindex(i: c_int) -> c_int { LUA_GLOBALSINDEX - i @@ -27,6 +37,16 @@ pub const LUA_ERRRUN: c_int = 2; pub const LUA_ERRSYNTAX: c_int = 3; pub const LUA_ERRMEM: c_int = 4; pub const LUA_ERRERR: c_int = 5; +pub const LUA_BREAK: c_int = 6; // yielded for a debug breakpoint + +// +// Coroutine status +// +pub const LUA_CORUN: c_int = 0; // running +pub const LUA_COSUS: c_int = 1; // suspended +pub const LUA_CONOR: c_int = 2; // 'normal' (it resumed another coroutine) +pub const LUA_COFIN: c_int = 3; // finished +pub const LUA_COERR: c_int = 4; // finished with error /// A raw Lua state associated with a thread. #[repr(C)] @@ -45,13 +65,15 @@ pub const LUA_TBOOLEAN: c_int = 1; pub const LUA_TLIGHTUSERDATA: c_int = 2; pub const LUA_TNUMBER: c_int = 3; -pub const LUA_TVECTOR: c_int = 4; +pub const LUA_TINTEGER: c_int = 4; +pub const LUA_TVECTOR: c_int = 5; -pub const LUA_TSTRING: c_int = 5; -pub const LUA_TTABLE: c_int = 6; -pub const LUA_TFUNCTION: c_int = 7; -pub const LUA_TUSERDATA: c_int = 8; -pub const LUA_TTHREAD: c_int = 9; +pub const LUA_TSTRING: c_int = 6; +pub const LUA_TTABLE: c_int = 7; +pub const LUA_TFUNCTION: c_int = 8; +pub const LUA_TUSERDATA: c_int = 9; +pub const LUA_TTHREAD: c_int = 10; +pub const LUA_TBUFFER: c_int = 11; /// Guaranteed number of Lua stack slots available to a C function. pub const LUA_MINSTACK: c_int = 20; @@ -59,28 +81,32 @@ pub const LUA_MINSTACK: c_int = 20; /// A Lua number, usually equivalent to `f64`. pub type lua_Number = c_double; -/// A Lua integer, equivalent to `i32`. -pub type lua_Integer = c_int; +/// A Lua integer, usually equivalent to `i64` +#[cfg(target_pointer_width = "32")] +pub type lua_Integer = i32; +#[cfg(target_pointer_width = "64")] +pub type lua_Integer = i64; /// A Lua unsigned integer, equivalent to `u32`. pub type lua_Unsigned = c_uint; /// Type for native C functions that can be passed to Lua. -pub type lua_CFunction = unsafe extern "C" fn(L: *mut lua_State) -> c_int; -pub type lua_Continuation = unsafe extern "C" fn(L: *mut lua_State, status: c_int) -> c_int; +pub type lua_CFunction = unsafe extern "C-unwind" fn(L: *mut lua_State) -> c_int; +pub type lua_Continuation = unsafe extern "C-unwind" fn(L: *mut lua_State, status: c_int) -> c_int; -/// Type for userdata destructor functions. -pub type lua_Udestructor = unsafe extern "C" fn(*mut c_void); +/// Type for userdata destructor functions (no unwinding). +pub type lua_Destructor = unsafe extern "C" fn(L: *mut lua_State, *mut c_void); -/// Type for memory-allocation functions. -pub type lua_Alloc = unsafe extern "C" fn( - ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, -) -> *mut c_void; +/// Type for memory-allocation functions (no unwinding). +pub type lua_Alloc = + unsafe extern "C" fn(ud: *mut c_void, ptr: *mut c_void, osize: usize, nsize: usize) -> *mut c_void; -extern "C" { +/// Returns Luau release version (eg. `0.xxx`). +pub const fn luau_version() -> Option<&'static str> { + option_env!("LUAU_VERSION") +} + +unsafe extern "C-unwind" { // // State manipulation // @@ -124,19 +150,31 @@ extern "C" { pub fn lua_tonumberx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Number; #[link_name = "lua_tointegerx"] - pub fn lua_tointegerx_(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Integer; + pub fn lua_tointegerx_(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> c_int; pub fn lua_tounsignedx(L: *mut lua_State, idx: c_int, isnum: *mut c_int) -> lua_Unsigned; pub fn lua_tovector(L: *mut lua_State, idx: c_int) -> *const c_float; pub fn lua_toboolean(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_tointeger64(L: *mut lua_State, idx: c_int, isinteger: *mut c_int) -> i64; pub fn lua_tolstring(L: *mut lua_State, idx: c_int, len: *mut usize) -> *const c_char; pub fn lua_tostringatom(L: *mut lua_State, idx: c_int, atom: *mut c_int) -> *const c_char; + pub fn lua_tolstringatom( + L: *mut lua_State, + idx: c_int, + len: *mut usize, + atom: *mut c_int, + ) -> *const c_char; pub fn lua_namecallatom(L: *mut lua_State, atom: *mut c_int) -> *const c_char; - pub fn lua_objlen(L: *mut lua_State, idx: c_int) -> usize; + #[link_name = "lua_objlen"] + pub fn lua_objlen_(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_tocfunction(L: *mut lua_State, idx: c_int) -> Option; + pub fn lua_tolightuserdata(L: *mut lua_State, idx: c_int) -> *mut c_void; + pub fn lua_tolightuserdatatagged(L: *mut lua_State, idx: c_int, tag: c_int) -> *mut c_void; pub fn lua_touserdata(L: *mut lua_State, idx: c_int) -> *mut c_void; pub fn lua_touserdatatagged(L: *mut lua_State, idx: c_int, tag: c_int) -> *mut c_void; pub fn lua_userdatatag(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_lightuserdatatag(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_tothread(L: *mut lua_State, idx: c_int) -> *mut lua_State; + pub fn lua_tobuffer(L: *mut lua_State, idx: c_int, len: *mut usize) -> *mut c_void; pub fn lua_topointer(L: *mut lua_State, idx: c_int) -> *const c_void; // @@ -144,9 +182,14 @@ extern "C" { // pub fn lua_pushnil(L: *mut lua_State); pub fn lua_pushnumber(L: *mut lua_State, n: lua_Number); - pub fn lua_pushinteger(L: *mut lua_State, n: lua_Integer); + #[link_name = "lua_pushinteger"] + pub fn lua_pushinteger_(L: *mut lua_State, n: c_int); + pub fn lua_pushinteger64(L: *mut lua_State, n: i64); pub fn lua_pushunsigned(L: *mut lua_State, n: lua_Unsigned); + #[cfg(not(feature = "luau-vector4"))] pub fn lua_pushvector(L: *mut lua_State, x: c_float, y: c_float, z: c_float); + #[cfg(feature = "luau-vector4")] + pub fn lua_pushvector(L: *mut lua_State, x: c_float, y: c_float, z: c_float, w: c_float); #[link_name = "lua_pushlstring"] pub fn lua_pushlstring_(L: *mut lua_State, s: *const c_char, l: usize); #[link_name = "lua_pushstring"] @@ -164,9 +207,12 @@ extern "C" { pub fn lua_pushboolean(L: *mut lua_State, b: c_int); pub fn lua_pushthread(L: *mut lua_State) -> c_int; - pub fn lua_pushlightuserdata(L: *mut lua_State, p: *mut c_void); + pub fn lua_pushlightuserdatatagged(L: *mut lua_State, p: *mut c_void, tag: c_int); pub fn lua_newuserdatatagged(L: *mut lua_State, sz: usize, tag: c_int) -> *mut c_void; - pub fn lua_newuserdatadtor(L: *mut lua_State, sz: usize, dtor: lua_Udestructor) -> *mut c_void; + pub fn lua_newuserdatataggedwithmetatable(L: *mut lua_State, sz: usize, tag: c_int) -> *mut c_void; + pub fn lua_newuserdatadtor(L: *mut lua_State, sz: usize, dtor: lua_Destructor) -> *mut c_void; + + pub fn lua_newbuffer(L: *mut lua_State, sz: usize) -> *mut c_void; // // Get functions (Lua -> stack) @@ -177,6 +223,7 @@ extern "C" { pub fn lua_rawget(L: *mut lua_State, idx: c_int) -> c_int; #[link_name = "lua_rawgeti"] pub fn lua_rawgeti_(L: *mut lua_State, idx: c_int, n: c_int) -> c_int; + pub fn lua_rawgetptagged(L: *mut lua_State, idx: c_int, p: *const c_void, tag: c_int) -> c_int; pub fn lua_createtable(L: *mut lua_State, narr: c_int, nrec: c_int); pub fn lua_setreadonly(L: *mut lua_State, idx: c_int, enabled: c_int); @@ -191,9 +238,11 @@ extern "C" { // pub fn lua_settable(L: *mut lua_State, idx: c_int); pub fn lua_setfield(L: *mut lua_State, idx: c_int, k: *const c_char); + pub fn lua_rawsetfield(L: *mut lua_State, idx: c_int, k: *const c_char); pub fn lua_rawset(L: *mut lua_State, idx: c_int); #[link_name = "lua_rawseti"] pub fn lua_rawseti_(L: *mut lua_State, idx: c_int, n: c_int); + pub fn lua_rawsetptagged(L: *mut lua_State, idx: c_int, p: *const c_void, tag: c_int); pub fn lua_setmetatable(L: *mut lua_State, objindex: c_int) -> c_int; pub fn lua_setfenv(L: *mut lua_State, idx: c_int) -> c_int; @@ -209,6 +258,7 @@ extern "C" { ) -> c_int; pub fn lua_call(L: *mut lua_State, nargs: c_int, nresults: c_int); pub fn lua_pcall(L: *mut lua_State, nargs: c_int, nresults: c_int, errfunc: c_int) -> c_int; + pub fn lua_cpcall(L: *mut lua_State, f: lua_CFunction, ud: *mut c_void) -> c_int; // // Coroutine functions @@ -222,6 +272,12 @@ extern "C" { pub fn lua_isyieldable(L: *mut lua_State) -> c_int; pub fn lua_getthreaddata(L: *mut lua_State) -> *mut c_void; pub fn lua_setthreaddata(L: *mut lua_State, data: *mut c_void); + pub fn lua_costatus(L: *mut lua_State, co: *mut lua_State) -> c_int; +} + +#[inline(always)] +pub unsafe fn lua_objlen(L: *mut lua_State, idx: c_int) -> usize { + lua_objlen_(L, idx) as usize } // @@ -238,14 +294,14 @@ pub const LUA_GCSETGOAL: c_int = 7; pub const LUA_GCSETSTEPMUL: c_int = 8; pub const LUA_GCSETSTEPSIZE: c_int = 9; -extern "C" { +unsafe extern "C-unwind" { pub fn lua_gc(L: *mut lua_State, what: c_int, data: c_int) -> c_int; } // // Memory statistics // -extern "C" { +unsafe extern "C-unwind" { pub fn lua_setmemcat(L: *mut lua_State, category: c_int); pub fn lua_totalbytes(L: *mut lua_State, category: c_int) -> usize; } @@ -253,18 +309,24 @@ extern "C" { // // Miscellaneous functions // -extern "C" { +unsafe extern "C-unwind" { pub fn lua_error(L: *mut lua_State) -> !; pub fn lua_next(L: *mut lua_State, idx: c_int) -> c_int; + pub fn lua_rawiter(L: *mut lua_State, idx: c_int, iter: c_int) -> c_int; pub fn lua_concat(L: *mut lua_State, n: c_int); - // TODO: lua_encodepointer + pub fn lua_encodepointer(L: *mut lua_State, p: usize) -> usize; pub fn lua_clock() -> c_double; - pub fn lua_setuserdatadtor( - L: *mut lua_State, - tag: c_int, - dtor: Option, - ); + pub fn lua_setuserdatatag(L: *mut lua_State, idx: c_int, tag: c_int); + pub fn lua_setuserdatadtor(L: *mut lua_State, tag: c_int, dtor: Option); + pub fn lua_getuserdatadtor(L: *mut lua_State, tag: c_int) -> Option; + pub fn lua_setuserdatametatable(L: *mut lua_State, tag: c_int); + pub fn lua_getuserdatametatable(L: *mut lua_State, tag: c_int); + pub fn lua_setlightuserdataname(L: *mut lua_State, tag: c_int, name: *const c_char); + pub fn lua_getlightuserdataname(L: *mut lua_State, tag: c_int) -> *const c_char; pub fn lua_clonefunction(L: *mut lua_State, idx: c_int); + pub fn lua_cleartable(L: *mut lua_State, idx: c_int); + pub fn lua_clonetable(L: *mut lua_State, idx: c_int); + pub fn lua_getallocf(L: *mut lua_State, ud: *mut *mut c_void) -> lua_Alloc; } // @@ -273,7 +335,7 @@ extern "C" { pub const LUA_NOREF: c_int = -1; pub const LUA_REFNIL: c_int = 0; -extern "C" { +unsafe extern "C-unwind" { pub fn lua_ref(L: *mut lua_State, idx: c_int) -> c_int; pub fn lua_unref(L: *mut lua_State, r#ref: c_int); } @@ -283,13 +345,13 @@ extern "C" { // #[inline(always)] -pub unsafe fn lua_tonumber(L: *mut lua_State, i: c_int) -> lua_Number { - lua_tonumberx(L, i, ptr::null_mut()) +pub unsafe fn lua_tonumber(L: *mut lua_State, idx: c_int) -> lua_Number { + lua_tonumberx(L, idx, ptr::null_mut()) } #[inline(always)] -pub unsafe fn lua_tointeger_(L: *mut lua_State, i: c_int) -> lua_Integer { - lua_tointegerx_(L, i, ptr::null_mut()) +pub unsafe fn lua_tointeger_(L: *mut lua_State, idx: c_int) -> c_int { + lua_tointegerx_(L, idx, ptr::null_mut()) } #[inline(always)] @@ -312,7 +374,21 @@ pub unsafe fn lua_newuserdata(L: *mut lua_State, sz: usize) -> *mut c_void { lua_newuserdatatagged(L, sz, 0) } -// TODO: lua_strlen +#[inline(always)] +pub unsafe fn lua_newuserdata_t(L: *mut lua_State, data: T) -> *mut T { + unsafe extern "C" fn destructor(_: *mut lua_State, ud: *mut c_void) { + ptr::drop_in_place(ud as *mut T); + } + + let ud_ptr = lua_newuserdatadtor(L, const { mem::size_of::() }, destructor::) as *mut T; + ptr::write(ud_ptr, data); + ud_ptr +} + +#[inline(always)] +pub unsafe fn lua_strlen(L: *mut lua_State, i: c_int) -> usize { + lua_objlen(L, i) +} #[inline(always)] pub unsafe fn lua_isfunction(L: *mut lua_State, n: c_int) -> c_int { @@ -339,6 +415,11 @@ pub unsafe fn lua_isboolean(L: *mut lua_State, n: c_int) -> c_int { (lua_type(L, n) == LUA_TBOOLEAN) as c_int } +#[inline(always)] +pub unsafe fn lua_isinteger64(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TINTEGER) as c_int +} + #[inline(always)] pub unsafe fn lua_isvector(L: *mut lua_State, n: c_int) -> c_int { (lua_type(L, n) == LUA_TVECTOR) as c_int @@ -349,6 +430,11 @@ pub unsafe fn lua_isthread(L: *mut lua_State, n: c_int) -> c_int { (lua_type(L, n) == LUA_TTHREAD) as c_int } +#[inline(always)] +pub unsafe fn lua_isbuffer(L: *mut lua_State, n: c_int) -> c_int { + (lua_type(L, n) == LUA_TBUFFER) as c_int +} + #[inline(always)] pub unsafe fn lua_isnone(L: *mut lua_State, n: c_int) -> c_int { (lua_type(L, n) == LUA_TNONE) as c_int @@ -360,33 +446,35 @@ pub unsafe fn lua_isnoneornil(L: *mut lua_State, n: c_int) -> c_int { } #[inline(always)] -pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static str) { - use std::ffi::CString; - let c_str = CString::new(s).unwrap(); - lua_pushlstring_(L, c_str.as_ptr(), c_str.as_bytes().len()) +pub unsafe fn lua_pushliteral(L: *mut lua_State, s: &'static CStr) { + lua_pushstring_(L, s.as_ptr()); } +#[inline(always)] pub unsafe fn lua_pushcfunction(L: *mut lua_State, f: lua_CFunction) { lua_pushcclosurek(L, f, ptr::null(), 0, None) } +#[inline(always)] pub unsafe fn lua_pushcfunctiond(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char) { lua_pushcclosurek(L, f, debugname, 0, None) } +#[inline(always)] pub unsafe fn lua_pushcclosure(L: *mut lua_State, f: lua_CFunction, nup: c_int) { lua_pushcclosurek(L, f, ptr::null(), nup, None) } -pub unsafe fn lua_pushcclosured( - L: *mut lua_State, - f: lua_CFunction, - debugname: *const c_char, - nup: c_int, -) { +#[inline(always)] +pub unsafe fn lua_pushcclosured(L: *mut lua_State, f: lua_CFunction, debugname: *const c_char, nup: c_int) { lua_pushcclosurek(L, f, debugname, nup, None) } +#[inline(always)] +pub unsafe fn lua_pushlightuserdata(L: *mut lua_State, p: *mut c_void) { + lua_pushlightuserdatatagged(L, p, 0) +} + #[inline(always)] pub unsafe fn lua_setglobal(L: *mut lua_State, var: *const c_char) { lua_setfield(L, LUA_GLOBALSINDEX, var) @@ -410,9 +498,9 @@ pub unsafe fn lua_tostring(L: *mut lua_State, i: c_int) -> *const c_char { const LUA_IDSIZE: usize = 256; /// Type for functions to be called on debug events. -pub type lua_Hook = unsafe extern "C" fn(L: *mut lua_State, ar: *mut lua_Debug); +pub type lua_Hook = unsafe extern "C-unwind" fn(L: *mut lua_State, ar: *mut lua_Debug); -pub type lua_Coverage = unsafe extern "C" fn( +pub type lua_Coverage = unsafe extern "C-unwind" fn( context: *mut c_void, function: *const c_char, linedefined: c_int, @@ -421,14 +509,15 @@ pub type lua_Coverage = unsafe extern "C" fn( size: usize, ); -extern "C" { +pub type lua_CounterFunction = + unsafe extern "C-unwind" fn(context: *mut c_void, function: *const c_char, linedefined: c_int); + +pub type lua_CounterValue = + unsafe extern "C-unwind" fn(context: *mut c_void, kind: c_int, line: c_int, hits: u64); + +unsafe extern "C-unwind" { pub fn lua_stackdepth(L: *mut lua_State) -> c_int; - pub fn lua_getinfo( - L: *mut lua_State, - level: c_int, - what: *const c_char, - ar: *mut lua_Debug, - ) -> c_int; + pub fn lua_getinfo(L: *mut lua_State, level: c_int, what: *const c_char, ar: *mut lua_Debug) -> c_int; pub fn lua_getargument(L: *mut lua_State, level: c_int, n: c_int) -> c_int; pub fn lua_getlocal(L: *mut lua_State, level: c_int, n: c_int) -> *const c_char; pub fn lua_setlocal(L: *mut lua_State, level: c_int, n: c_int) -> *const c_char; @@ -436,13 +525,16 @@ extern "C" { pub fn lua_setupvalue(L: *mut lua_State, funcindex: c_int, n: c_int) -> *const c_char; pub fn lua_singlestep(L: *mut lua_State, enabled: c_int); - pub fn lua_breakpoint(L: *mut lua_State, funcindex: c_int, line: c_int, enabled: c_int); + pub fn lua_breakpoint(L: *mut lua_State, funcindex: c_int, line: c_int, enabled: c_int) -> c_int; + + pub fn lua_getcoverage(L: *mut lua_State, funcindex: c_int, context: *mut c_void, callback: lua_Coverage); - pub fn lua_getcoverage( + pub fn lua_getcounters( L: *mut lua_State, funcindex: c_int, context: *mut c_void, - callback: lua_Coverage, + functionvisit: lua_CounterFunction, + countervisit: lua_CounterValue, ); pub fn lua_debugtrace(L: *mut lua_State) -> *const c_char; @@ -453,13 +545,14 @@ pub struct lua_Debug { pub name: *const c_char, pub what: *const c_char, pub source: *const c_char, + pub short_src: *const c_char, pub linedefined: c_int, pub currentline: c_int, pub nupvals: u8, pub nparams: u8, pub isvararg: c_char, - pub short_src: [c_char; LUA_IDSIZE], pub userdata: *mut c_void, + pub ssbuf: [c_char; LUA_IDSIZE], } // @@ -468,30 +561,45 @@ pub struct lua_Debug { // #[repr(C)] +#[non_exhaustive] pub struct lua_Callbacks { /// arbitrary userdata pointer that is never overwritten by Luau pub userdata: *mut c_void, /// gets called at safepoints (loop back edges, call/ret, gc) if set - pub interrupt: Option, + pub interrupt: Option, /// gets called when an unprotected error is raised (if longjmp is used) - pub panic: Option, + pub panic: Option, /// gets called when L is created (LP == parent) or destroyed (LP == NULL) - pub userthread: Option, - /// gets called when a string is created; returned atom can be retrieved via tostringatom - pub useratom: Option i16>, + pub userthread: Option, + /// gets called when a string is created to assign an atom id + pub useratom: Option i16>, /// gets called when BREAK instruction is encountered - pub debugbreak: Option, + pub debugbreak: Option, /// gets called after each instruction in single step mode - pub debugstep: Option, + pub debugstep: Option, /// gets called when thread execution is interrupted by break in another thread - pub debuginterrupt: Option, + pub debuginterrupt: Option, /// gets called when protected call results in an error - pub debugprotectederror: Option, + pub debugprotectederror: Option, + + /// gets called when memory is allocated + pub onallocate: Option, } -extern "C" { +unsafe extern "C" { pub fn lua_callbacks(L: *mut lua_State) -> *mut lua_Callbacks; } + +// Functions from customization lib +unsafe extern "C" { + pub fn luau_setfflag(name: *const c_char, value: c_int) -> c_int; + pub fn lua_getmetatablepointer(L: *mut lua_State, idx: c_int) -> *const c_void; + pub fn lua_gcdump( + L: *mut lua_State, + file: *mut c_void, + category_name: Option *const c_char>, + ); +} diff --git a/mlua-sys/src/luau/luacode.rs b/mlua-sys/src/luau/luacode.rs new file mode 100644 index 00000000..1d74d453 --- /dev/null +++ b/mlua-sys/src/luau/luacode.rs @@ -0,0 +1,114 @@ +//! Contains definitions from `luacode.h`. + +use std::marker::{PhantomData, PhantomPinned}; +use std::os::raw::{c_char, c_int, c_void}; +use std::{ptr, slice}; + +#[repr(C)] +#[non_exhaustive] +pub struct lua_CompileOptions { + pub optimizationLevel: c_int, + pub debugLevel: c_int, + pub typeInfoLevel: c_int, + pub coverageLevel: c_int, + pub vectorLib: *const c_char, + pub vectorCtor: *const c_char, + pub vectorType: *const c_char, + pub mutableGlobals: *const *const c_char, + pub userdataTypes: *const *const c_char, + pub librariesWithKnownMembers: *const *const c_char, + pub libraryMemberTypeCallback: Option, + pub libraryMemberConstantCallback: Option, + pub disabledBuiltins: *const *const c_char, +} + +impl Default for lua_CompileOptions { + fn default() -> Self { + Self { + optimizationLevel: 1, + debugLevel: 1, + typeInfoLevel: 0, + coverageLevel: 0, + vectorLib: ptr::null(), + vectorCtor: ptr::null(), + vectorType: ptr::null(), + mutableGlobals: ptr::null(), + userdataTypes: ptr::null(), + librariesWithKnownMembers: ptr::null(), + libraryMemberTypeCallback: None, + libraryMemberConstantCallback: None, + disabledBuiltins: ptr::null(), + } + } +} + +#[repr(C)] +pub struct lua_CompileConstant { + _data: [u8; 0], + _marker: PhantomData<(*mut u8, PhantomPinned)>, +} + +/// Type table tags +#[doc(hidden)] +#[repr(i32)] +#[non_exhaustive] +pub enum luau_BytecodeType { + Nil = 0, + Boolean, + Number, + String, + Table, + Function, + Thread, + UserData, + Vector, + Buffer, + + Any = 15, +} + +pub type lua_LibraryMemberTypeCallback = + unsafe extern "C-unwind" fn(library: *const c_char, member: *const c_char) -> c_int; + +pub type lua_LibraryMemberConstantCallback = unsafe extern "C-unwind" fn( + library: *const c_char, + member: *const c_char, + constant: *mut lua_CompileConstant, +); + +unsafe extern "C" { + pub fn luau_set_compile_constant_nil(cons: *mut lua_CompileConstant); + pub fn luau_set_compile_constant_boolean(cons: *mut lua_CompileConstant, b: c_int); + pub fn luau_set_compile_constant_number(cons: *mut lua_CompileConstant, n: f64); + pub fn luau_set_compile_constant_integer64(cons: *mut lua_CompileConstant, l: i64); + pub fn luau_set_compile_constant_vector(cons: *mut lua_CompileConstant, x: f32, y: f32, z: f32, w: f32); + pub fn luau_set_compile_constant_string(cons: *mut lua_CompileConstant, s: *const c_char, l: usize); +} + +unsafe extern "C-unwind" { + #[link_name = "luau_compile"] + pub fn luau_compile_( + source: *const c_char, + size: usize, + options: *mut lua_CompileOptions, + outsize: *mut usize, + ) -> *mut c_char; +} + +unsafe extern "C" { + fn free(p: *mut c_void); +} + +pub unsafe fn luau_compile(source: &[u8], mut options: lua_CompileOptions) -> Vec { + let mut outsize = 0; + let data_ptr = luau_compile_( + source.as_ptr() as *const c_char, + source.len(), + &mut options, + &mut outsize, + ); + assert!(!data_ptr.is_null(), "luau_compile failed"); + let data = slice::from_raw_parts(data_ptr as *mut u8, outsize).to_vec(); + free(data_ptr as *mut c_void); + data +} diff --git a/mlua-sys/src/luau/luacodegen.rs b/mlua-sys/src/luau/luacodegen.rs new file mode 100644 index 00000000..9e063ed2 --- /dev/null +++ b/mlua-sys/src/luau/luacodegen.rs @@ -0,0 +1,11 @@ +//! Contains definitions from `luacodegen.h`. + +use std::os::raw::c_int; + +use super::lua::lua_State; + +unsafe extern "C-unwind" { + pub fn luau_codegen_supported() -> c_int; + pub fn luau_codegen_create(state: *mut lua_State); + pub fn luau_codegen_compile(state: *mut lua_State, idx: c_int); +} diff --git a/mlua-sys/src/luau/lualib.rs b/mlua-sys/src/luau/lualib.rs new file mode 100644 index 00000000..02ccf561 --- /dev/null +++ b/mlua-sys/src/luau/lualib.rs @@ -0,0 +1,35 @@ +//! Contains definitions from `lualib.h`. + +use std::os::raw::{c_char, c_int}; + +use super::lua::lua_State; + +pub const LUA_COLIBNAME: *const c_char = cstr!("coroutine"); +pub const LUA_TABLIBNAME: *const c_char = cstr!("table"); +pub const LUA_OSLIBNAME: *const c_char = cstr!("os"); +pub const LUA_STRLIBNAME: *const c_char = cstr!("string"); +pub const LUA_BITLIBNAME: *const c_char = cstr!("bit32"); +pub const LUA_BUFFERLIBNAME: *const c_char = cstr!("buffer"); +pub const LUA_UTF8LIBNAME: *const c_char = cstr!("utf8"); +pub const LUA_MATHLIBNAME: *const c_char = cstr!("math"); +pub const LUA_DBLIBNAME: *const c_char = cstr!("debug"); +pub const LUA_VECLIBNAME: *const c_char = cstr!("vector"); +pub const LUA_INTLIBNAME: *const c_char = cstr!("integer"); + +unsafe extern "C-unwind" { + pub fn luaopen_base(L: *mut lua_State) -> c_int; + pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; + pub fn luaopen_table(L: *mut lua_State) -> c_int; + pub fn luaopen_os(L: *mut lua_State) -> c_int; + pub fn luaopen_string(L: *mut lua_State) -> c_int; + pub fn luaopen_bit32(L: *mut lua_State) -> c_int; + pub fn luaopen_buffer(L: *mut lua_State) -> c_int; + pub fn luaopen_utf8(L: *mut lua_State) -> c_int; + pub fn luaopen_math(L: *mut lua_State) -> c_int; + pub fn luaopen_debug(L: *mut lua_State) -> c_int; + pub fn luaopen_vector(L: *mut lua_State) -> c_int; + pub fn luaopen_integer(L: *mut lua_State) -> c_int; + + // open all builtin libraries + pub fn luaL_openlibs(L: *mut lua_State); +} diff --git a/mlua-sys/src/luau/luarequire.rs b/mlua-sys/src/luau/luarequire.rs new file mode 100644 index 00000000..0574efd4 --- /dev/null +++ b/mlua-sys/src/luau/luarequire.rs @@ -0,0 +1,219 @@ +//! Contains definitions from `Require.h`. + +use std::os::raw::{c_char, c_int, c_void}; + +use super::lua::lua_State; + +pub const LUA_REGISTERED_MODULES_TABLE: *const c_char = cstr!("_REGISTEREDMODULES"); + +#[repr(C)] +pub enum luarequire_NavigateResult { + Success, + Ambiguous, + NotFound, +} + +// Functions returning WriteSuccess are expected to set their size_out argument +// to the number of bytes written to the buffer. If WriteBufferTooSmall is +// returned, size_out should be set to the required buffer size. +#[repr(C)] +pub enum luarequire_WriteResult { + Success, + BufferTooSmall, + Failure, +} + +/// Represents whether a configuration file is present, and if so, its syntax. +#[repr(C)] +pub enum luarequire_ConfigStatus { + Absent, + // Signals the presence of multiple configuration files + Ambiguous, + PresentJson, + PresentLuau, +} + +#[repr(C)] +pub struct luarequire_Configuration { + // Returns whether requires are permitted from the given chunkname. + pub is_require_allowed: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + requirer_chunkname: *const c_char, + ) -> bool, + + // Resets the internal state to point at the requirer module. + pub reset: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + requirer_chunkname: *const c_char, + ) -> luarequire_NavigateResult, + + // Resets the internal state to point at an aliased module, given its exact path from a configuration + // file. This function is only called when an alias's path cannot be resolved relative to its + // configuration file. + pub jump_to_alias: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + path: *const c_char, + ) -> luarequire_NavigateResult, + + // Provides an initial alias override opportunity prior to searching for configuration files. + // If NAVIGATE_SUCCESS is returned, the internal state must be updated to point at the + // aliased location. + // Can be left undefined. + pub to_alias_override: Option< + unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + alias_unprefixed: *const c_char, + ) -> luarequire_NavigateResult, + >, + + // Provides a final override opportunity if an alias cannot be found in configuration files. If + // NAVIGATE_SUCCESS is returned, this must update the internal state to point at the aliased module. + // Can be left undefined. + pub to_alias_fallback: Option< + unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + alias_unprefixed: *const c_char, + ) -> luarequire_NavigateResult, + >, + + // Navigates through the context by making mutations to the internal state. + pub to_parent: + unsafe extern "C-unwind" fn(L: *mut lua_State, ctx: *mut c_void) -> luarequire_NavigateResult, + pub to_child: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + name: *const c_char, + ) -> luarequire_NavigateResult, + + // Returns whether the context is currently pointing at a module. + pub is_module_present: unsafe extern "C-unwind" fn(L: *mut lua_State, ctx: *mut c_void) -> bool, + + // Provides a chunkname for the current module. This will be accessible through the debug library. This + // function is only called if is_module_present returns true. + pub get_chunkname: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> luarequire_WriteResult, + + // Provides a loadname that identifies the current module and is passed to load. This function + // is only called if is_module_present returns true. + pub get_loadname: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> luarequire_WriteResult, + + // Provides a cache key representing the current module. This function is only called if + // is_module_present returns true. + pub get_cache_key: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> luarequire_WriteResult, + + // Returns whether a configuration file is present in the current context, and if so, its syntax. + // If not present, require-by-string will call to_parent until either a configuration file is present or + // NAVIGATE_FAILURE is returned (at root). + pub get_config_status: + unsafe extern "C-unwind" fn(L: *mut lua_State, ctx: *mut c_void) -> luarequire_ConfigStatus, + + // Parses the configuration file in the current context for the given alias and returns its + // value or WRITE_FAILURE if not found. This function is only called if get_config_status + // returns true. If this function pointer is set, get_config must not be set. Opting in to this + // function pointer disables parsing configuration files internally and can be used for finer + // control over the configuration file parsing process. + pub get_alias: Option< + unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + alias: *const c_char, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> luarequire_WriteResult, + >, + + // Provides the contents of the configuration file in the current context. + // This function is only called if get_config_status does not return CONFIG_ABSENT. If this function + // pointer is set, get_alias must not be set. Opting in to this function pointer enables parsing + // configuration files internally. + pub get_config: Option< + unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> luarequire_WriteResult, + >, + + // Returns the maximum number of milliseconds to allow for executing a given Luau-syntax configuration + // file. This function is only called if get_config_status returns CONFIG_PRESENT_LUAU and can be left + // undefined if support for Luau-syntax configuration files is not needed. A default value of 2000ms is + // used. Negative values are treated as infinite. + pub get_luau_config_timeout: + Option c_int>, + + // Executes the module and places the result on the stack. Returns the number of results placed on the + // stack. + // Returning -1 directs the requiring thread to yield. In this case, this thread should be resumed with + // the module result pushed onto its stack. + pub load: unsafe extern "C-unwind" fn( + L: *mut lua_State, + ctx: *mut c_void, + path: *const c_char, + chunkname: *const c_char, + loadname: *const c_char, + ) -> c_int, +} + +// Populates function pointers in the given luarequire_Configuration. +pub type luarequire_Configuration_init = unsafe extern "C-unwind" fn(config: *mut luarequire_Configuration); + +unsafe extern "C-unwind" { + // Initializes and pushes the require closure onto the stack without registration. + pub fn luarequire_pushrequire( + L: *mut lua_State, + config_init: luarequire_Configuration_init, + ctx: *mut c_void, + ) -> c_int; + + // Initializes the require library and registers it globally. + pub fn luaopen_require(L: *mut lua_State, config_init: luarequire_Configuration_init, ctx: *mut c_void); + + // Initializes and pushes a "proxyrequire" closure onto the stack. + // + // The closure takes two parameters: the string path to resolve and the chunkname of an existing + // module. + pub fn luarequire_pushproxyrequire( + L: *mut lua_State, + config_init: luarequire_Configuration_init, + ctx: *mut c_void, + ) -> c_int; + + // Registers an aliased require path to a result. + // + // After registration, the given result will always be immediately returned when the given path is + // required. + // Expects the path and table to be passed as arguments on the stack. + pub fn luarequire_registermodule(L: *mut lua_State) -> c_int; + + // Clears the entry associated with the given cache key from the require cache. + // Expects the cache key to be passed as an argument on the stack. + pub fn luarequire_clearcacheentry(L: *mut lua_State) -> c_int; + + // Clears all entries from the require cache. + pub fn luarequire_clearcache(L: *mut lua_State) -> c_int; +} diff --git a/src/ffi/luau/mod.rs b/mlua-sys/src/luau/mod.rs similarity index 70% rename from src/ffi/luau/mod.rs rename to mlua-sys/src/luau/mod.rs index e46093c3..ea882a99 100644 --- a/src/ffi/luau/mod.rs +++ b/mlua-sys/src/luau/mod.rs @@ -4,10 +4,14 @@ pub use compat::*; pub use lauxlib::*; pub use lua::*; pub use luacode::*; +pub use luacodegen::*; pub use lualib::*; +pub use luarequire::*; pub mod compat; pub mod lauxlib; pub mod lua; pub mod luacode; +pub mod luacodegen; pub mod lualib; +pub mod luarequire; diff --git a/mlua-sys/src/macros.rs b/mlua-sys/src/macros.rs new file mode 100644 index 00000000..263b4e76 --- /dev/null +++ b/mlua-sys/src/macros.rs @@ -0,0 +1,6 @@ +#[allow(unused_macros)] +macro_rules! cstr { + ($s:expr) => { + concat!($s, "\0") as *const str as *const [::std::os::raw::c_char] as *const ::std::os::raw::c_char + }; +} diff --git a/mlua_derive/Cargo.toml b/mlua_derive/Cargo.toml index 0554cade..74d3c1ad 100644 --- a/mlua_derive/Cargo.toml +++ b/mlua_derive/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "mlua_derive" -version = "0.8.0-beta.1" +version = "0.11.0" authors = ["Aleksandr Orlenko "] -edition = "2018" +edition = "2021" description = "Procedural macros for the mlua crate." -repository = "https://github.com/khvzak/mlua" +repository = "https://github.com/mlua-rs/mlua" keywords = ["lua", "mlua"] license = "MIT" @@ -12,13 +12,13 @@ license = "MIT" proc-macro = true [features] -macros = ["proc-macro-error", "itertools", "regex", "once_cell"] +macros = ["proc-macro-error2", "itertools", "regex", "once_cell"] [dependencies] quote = "1.0" proc-macro2 = { version = "1.0", features = ["span-locations"] } -proc-macro-error = { version = "1.0", optional = true } -syn = { version = "1.0", features = ["full"] } -itertools = { version = "0.10", optional = true } +proc-macro-error2 = { version = "2.0.1", optional = true } +syn = { version = "2.0", features = ["full"] } +itertools = { version = "0.14", optional = true } regex = { version = "1.4", optional = true } once_cell = { version = "1.0", optional = true } diff --git a/mlua_derive/src/from_lua.rs b/mlua_derive/src/from_lua.rs new file mode 100644 index 00000000..e74eb868 --- /dev/null +++ b/mlua_derive/src/from_lua.rs @@ -0,0 +1,31 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +pub fn from_lua(input: TokenStream) -> TokenStream { + let DeriveInput { ident, generics, .. } = parse_macro_input!(input as DeriveInput); + + let ident_str = ident.to_string(); + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let where_clause = match &generics.where_clause { + Some(where_clause) => quote! { #where_clause, Self: 'static + Clone }, + None => quote! { where Self: 'static + Clone }, + }; + + quote! { + impl #impl_generics ::mlua::FromLua for #ident #ty_generics #where_clause { + #[inline] + fn from_lua(value: ::mlua::Value, _: &::mlua::Lua) -> ::mlua::Result { + match value { + ::mlua::Value::UserData(ud) => Ok(ud.borrow::()?.clone()), + _ => Err(::mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: #ident_str.to_string(), + message: None, + }), + } + } + } + } + .into() +} diff --git a/mlua_derive/src/lib.rs b/mlua_derive/src/lib.rs index afa8b547..f7d04803 100644 --- a/mlua_derive/src/lib.rs +++ b/mlua_derive/src/lib.rs @@ -1,38 +1,73 @@ use proc_macro::TokenStream; use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{parse_macro_input, AttributeArgs, Error, ItemFn}; +use syn::meta::ParseNestedMeta; +use syn::{parse_macro_input, ItemFn, LitStr, Result}; #[cfg(feature = "macros")] use { crate::chunk::Chunk, proc_macro::TokenTree, proc_macro2::TokenStream as TokenStream2, - proc_macro_error::proc_macro_error, + proc_macro_error2::proc_macro_error, }; +#[derive(Default)] +struct ModuleAttributes { + name: Option, + skip_memory_check: bool, +} + +impl ModuleAttributes { + fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> { + if meta.path.is_ident("name") { + match meta.value() { + Ok(value) => { + self.name = Some(value.parse::()?.parse()?); + } + Err(_) => { + return Err(meta.error("`name` attribute must have a value")); + } + } + } else if meta.path.is_ident("skip_memory_check") { + if meta.value().is_ok() { + return Err(meta.error("`skip_memory_check` attribute have no values")); + } + self.skip_memory_check = true; + } else { + return Err(meta.error("unsupported module attribute")); + } + Ok(()) + } +} + #[proc_macro_attribute] pub fn lua_module(attr: TokenStream, item: TokenStream) -> TokenStream { - let args = parse_macro_input!(attr as AttributeArgs); - let func = parse_macro_input!(item as ItemFn); - - if !args.is_empty() { - let err = Error::new(Span::call_site(), "the macro does not support arguments") - .to_compile_error(); - return err.into(); + let mut args = ModuleAttributes::default(); + if !attr.is_empty() { + let args_parser = syn::meta::parser(|meta| args.parse(meta)); + parse_macro_input!(attr with args_parser); } - let func_name = func.sig.ident.clone(); - let ext_entrypoint_name = Ident::new(&format!("luaopen_{}", func_name), Span::call_site()); + let func = parse_macro_input!(item as ItemFn); + let func_name = &func.sig.ident; + let module_name = args.name.unwrap_or_else(|| func_name.clone()); + let ext_entrypoint_name = Ident::new(&format!("luaopen_{module_name}"), Span::call_site()); + let skip_memory_check = if args.skip_memory_check { + quote! { lua.skip_memory_check(true); } + } else { + quote! {} + }; let wrapped = quote! { - ::mlua::require_module_feature!(); + mlua::require_module_feature!(); #func #[no_mangle] - unsafe extern "C" fn #ext_entrypoint_name(state: *mut ::mlua::lua_State) -> ::std::os::raw::c_int { - ::mlua::Lua::init_from_ptr(state) - .entrypoint1(#func_name) - .expect("cannot initialize module") + unsafe extern "C-unwind" fn #ext_entrypoint_name(state: *mut mlua::lua_State) -> ::std::os::raw::c_int { + mlua::Lua::entrypoint1(state, move |lua| { + #skip_memory_check + #func_name(lua) + }) } }; @@ -61,30 +96,21 @@ pub fn chunk(input: TokenStream) -> TokenStream { }); let wrapped_code = quote! {{ - use ::mlua::{AsChunk, ChunkMode, Lua, Result, Value}; + use mlua::{AsChunk, ChunkMode, Lua, Result, Table}; use ::std::borrow::Cow; + use ::std::cell::Cell; use ::std::io::Result as IoResult; - use ::std::marker::PhantomData; - use ::std::sync::Mutex; - - fn annotate<'a, F: FnOnce(&'a Lua) -> Result>>(f: F) -> F { f } - struct InnerChunk<'a, F: FnOnce(&'a Lua) -> Result>>(Mutex>, PhantomData<&'a ()>); + struct InnerChunk Result
>(Cell>); - impl<'lua, F> AsChunk<'lua> for InnerChunk<'lua, F> + impl AsChunk for InnerChunk where - F: FnOnce(&'lua Lua) -> Result>, + F: FnOnce(&Lua) -> Result
, { - fn source(&self) -> IoResult> { - Ok(Cow::Borrowed((#source).as_bytes())) - } - - fn env(&self, lua: &'lua Lua) -> Result>> { + fn environment(&self, lua: &Lua) -> Result> { if #caps_len > 0 { - if let Ok(mut make_env) = self.0.lock() { - if let Some(make_env) = make_env.take() { - return make_env(lua).map(Some); - } + if let Some(make_env) = self.0.take() { + return make_env(lua).map(Some); } } Ok(None) @@ -93,29 +119,41 @@ pub fn chunk(input: TokenStream) -> TokenStream { fn mode(&self) -> Option { Some(ChunkMode::Text) } + + fn source<'a>(&self) -> IoResult> { + Ok(Cow::Borrowed((#source).as_bytes())) + } } - let make_env = annotate(move |lua: &Lua| -> Result { + let make_env = move |lua: &Lua| -> Result
{ let globals = lua.globals(); let env = lua.create_table()?; let meta = lua.create_table()?; - meta.raw_set("__index", globals.clone())?; - meta.raw_set("__newindex", globals)?; + meta.raw_set("__index", &globals)?; + meta.raw_set("__newindex", &globals)?; // Add captured variables #(#caps)* - env.set_metatable(Some(meta)); - Ok(Value::Table(env)) - }); + env.set_metatable(Some(meta))?; + Ok(env) + }; - &InnerChunk(Mutex::new(Some(make_env)), PhantomData) + InnerChunk(Cell::new(Some(make_env))) }}; wrapped_code.into() } +#[cfg(feature = "macros")] +#[proc_macro_derive(FromLua)] +pub fn from_lua(input: TokenStream) -> TokenStream { + from_lua::from_lua(input) +} + #[cfg(feature = "macros")] mod chunk; #[cfg(feature = "macros")] +mod from_lua; +#[cfg(feature = "macros")] mod token; diff --git a/mlua_derive/src/token.rs b/mlua_derive/src/token.rs index 95eba3d6..c6ce7c97 100644 --- a/mlua_derive/src/token.rs +++ b/mlua_derive/src/token.rs @@ -1,9 +1,6 @@ -use std::{ - cmp::{Eq, PartialEq}, - fmt::{self, Display, Formatter}, - iter::IntoIterator, - vec::IntoIter, -}; +use std::cmp::{Eq, PartialEq}; +use std::fmt::{self, Display, Formatter}; +use std::vec::IntoIter; use itertools::Itertools; use once_cell::sync::Lazy; @@ -48,10 +45,7 @@ fn span_pos(span: &Span) -> (Pos, Pos) { return fallback_span_pos(span); } - ( - Pos::new(start.line, start.column), - Pos::new(end.line, end.column), - ) + (Pos::new(start.line, start.column), Pos::new(end.line, end.column)) } fn parse_pos(span: &Span) -> Option<(usize, usize)> { @@ -59,7 +53,7 @@ fn parse_pos(span: &Span) -> Option<(usize, usize)> { static RE: Lazy = Lazy::new(|| Regex::new(r"bytes\(([0-9]+)\.\.([0-9]+)\)").unwrap()); - match RE.captures(&format!("{:?}", span)) { + match RE.captures(&format!("{span:?}")) { Some(caps) => match (caps.get(1), caps.get(2)) { (Some(start), Some(end)) => Some(( match start.as_str().parse() { @@ -80,9 +74,7 @@ fn parse_pos(span: &Span) -> Option<(usize, usize)> { fn fallback_span_pos(span: &Span) -> (Pos, Pos) { let (start, end) = match parse_pos(span) { Some(v) => v, - None => proc_macro_error::abort_call_site!( - "Cannot retrieve span information; please use nightly" - ), + None => proc_macro_error2::abort_call_site!("Cannot retrieve span information; please use nightly"), }; (Pos::new(1, start), Pos::new(1, end)) } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..ac702a58 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +imports_granularity = "Module" +max_width = 110 +comment_width = 100 +wrap_comments = true diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 00000000..d27f8b3d --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,167 @@ +use std::io; + +#[cfg(feature = "serde")] +use serde::ser::{Serialize, Serializer}; + +use crate::state::RawLua; +use crate::types::ValueRef; + +/// A Luau buffer type. +/// +/// See the buffer [documentation] for more information. +/// +/// [documentation]: https://luau.org/library#buffer-library +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +#[derive(Clone, Debug, PartialEq)] +pub struct Buffer(pub(crate) ValueRef); + +#[cfg_attr(not(feature = "luau"), allow(unused))] +impl Buffer { + /// Copies the buffer data into a new `Vec`. + pub fn to_vec(&self) -> Vec { + let lua = self.0.lua.lock(); + self.as_slice(&lua).to_vec() + } + + /// Returns the length of the buffer. + pub fn len(&self) -> usize { + let lua = self.0.lua.lock(); + self.as_slice(&lua).len() + } + + /// Returns `true` if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Reads given number of bytes from the buffer at the given offset. + /// + /// Offset is 0-based. + #[track_caller] + pub fn read_bytes(&self, offset: usize) -> [u8; N] { + let lua = self.0.lua.lock(); + let data = self.as_slice(&lua); + let mut bytes = [0u8; N]; + bytes.copy_from_slice(&data[offset..offset + N]); + bytes + } + + /// Writes given bytes to the buffer at the given offset. + /// + /// Offset is 0-based. + #[track_caller] + pub fn write_bytes(&self, offset: usize, bytes: &[u8]) { + let lua = self.0.lua.lock(); + let data = self.as_slice_mut(&lua); + data[offset..offset + bytes.len()].copy_from_slice(bytes); + } + + /// Returns an adaptor implementing [`io::Read`], [`io::Write`] and [`io::Seek`] over the + /// buffer. + /// + /// Buffer operations are infallible, none of the read/write functions will return an Err. + pub fn cursor(self) -> impl io::Read + io::Write + io::Seek { + BufferCursor(self, 0) + } + + pub(crate) fn as_slice(&self, lua: &RawLua) -> &[u8] { + unsafe { + let (buf, size) = self.as_raw_parts(lua); + std::slice::from_raw_parts(buf, size) + } + } + + #[allow(clippy::mut_from_ref)] + fn as_slice_mut(&self, lua: &RawLua) -> &mut [u8] { + unsafe { + let (buf, size) = self.as_raw_parts(lua); + std::slice::from_raw_parts_mut(buf, size) + } + } + + #[cfg(feature = "luau")] + unsafe fn as_raw_parts(&self, lua: &RawLua) -> (*mut u8, usize) { + let mut size = 0usize; + let buf = ffi::lua_tobuffer(lua.ref_thread(), self.0.index, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + (buf as *mut u8, size) + } + + #[cfg(not(feature = "luau"))] + unsafe fn as_raw_parts(&self, lua: &RawLua) -> (*mut u8, usize) { + unreachable!() + } +} + +struct BufferCursor(Buffer, usize); + +impl io::Read for BufferCursor { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let lua = self.0.0.lua.lock(); + let data = self.0.as_slice(&lua); + if self.1 == data.len() { + return Ok(0); + } + let len = buf.len().min(data.len() - self.1); + buf[..len].copy_from_slice(&data[self.1..self.1 + len]); + self.1 += len; + Ok(len) + } +} + +impl io::Write for BufferCursor { + fn write(&mut self, buf: &[u8]) -> io::Result { + let lua = self.0.0.lua.lock(); + let data = self.0.as_slice_mut(&lua); + if self.1 == data.len() { + return Ok(0); + } + let len = buf.len().min(data.len() - self.1); + data[self.1..self.1 + len].copy_from_slice(&buf[..len]); + self.1 += len; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl io::Seek for BufferCursor { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + let lua = self.0.0.lua.lock(); + let data = self.0.as_slice(&lua); + let new_offset = match pos { + io::SeekFrom::Start(offset) => offset as i64, + io::SeekFrom::End(offset) => data.len() as i64 + offset, + io::SeekFrom::Current(offset) => self.1 as i64 + offset, + }; + if new_offset < 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to a negative position", + )); + } + if new_offset as usize > data.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid seek to a position beyond the end of the buffer", + )); + } + self.1 = new_offset as usize; + Ok(self.1 as u64) + } +} + +#[cfg(feature = "serde")] +impl Serialize for Buffer { + fn serialize(&self, serializer: S) -> std::result::Result { + let lua = self.0.lua.lock(); + serializer.serialize_bytes(self.as_slice(&lua)) + } +} + +#[cfg(feature = "luau")] +impl crate::types::LuaType for Buffer { + const TYPE_ID: std::os::raw::c_int = ffi::LUA_TBUFFER; +} diff --git a/src/chunk.rs b/src/chunk.rs index 11617ff2..5c25f773 100644 --- a/src/chunk.rs +++ b/src/chunk.rs @@ -1,36 +1,40 @@ +//! Lua chunk loading and execution. +//! +//! This module provides types for loading Lua source code or bytecode into a [`Chunk`], +//! configuring how it is compiled and executed, and converting it into a callable [`Function`]. +//! +//! Chunks can be loaded from strings, byte slices, or files via the [`AsChunk`] trait. + use std::borrow::Cow; use std::collections::HashMap; use std::ffi::CString; use std::io::Result as IoResult; +use std::panic::Location; use std::path::{Path, PathBuf}; -use std::string::String as StdString; use crate::error::{Error, Result}; -use crate::ffi; use crate::function::Function; -use crate::lua::Lua; -use crate::value::{FromLuaMulti, ToLua, ToLuaMulti, Value}; - -#[cfg(feature = "async")] -use {futures_core::future::LocalBoxFuture, futures_util::future}; +use crate::state::{Lua, WeakLua}; +use crate::table::Table; +use crate::traits::{FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::value::Value; /// Trait for types [loadable by Lua] and convertible to a [`Chunk`] /// /// [loadable by Lua]: https://www.lua.org/manual/5.4/manual.html#3.3.2 -/// [`Chunk`]: crate::Chunk -pub trait AsChunk<'lua> { - /// Returns chunk data (can be text or binary) - fn source(&self) -> IoResult>; - +pub trait AsChunk { /// Returns optional chunk name - fn name(&self) -> Option { + /// + /// See [`Chunk::set_name`] for possible name prefixes. + fn name(&self) -> Option { None } /// Returns optional chunk [environment] /// /// [environment]: https://www.lua.org/manual/5.4/manual.html#2.2 - fn env(&self, _lua: &'lua Lua) -> Result>> { + fn environment(&self, lua: &Lua) -> Result> { + let _lua = lua; // suppress warning Ok(None) } @@ -38,62 +42,110 @@ pub trait AsChunk<'lua> { fn mode(&self) -> Option { None } + + /// Returns chunk data (can be text or binary) + fn source<'a>(&self) -> IoResult> + where + Self: 'a; +} + +impl AsChunk for &str { + fn source<'a>(&self) -> IoResult> + where + Self: 'a, + { + Ok(Cow::Borrowed(self.as_bytes())) + } } -impl<'lua> AsChunk<'lua> for str { - fn source(&self) -> IoResult> { - Ok(Cow::Borrowed(self.as_ref())) +impl AsChunk for String { + fn source<'a>(&self) -> IoResult> { + Ok(Cow::Owned(self.clone().into_bytes())) } } -impl<'lua> AsChunk<'lua> for StdString { - fn source(&self) -> IoResult> { - Ok(Cow::Borrowed(self.as_ref())) +impl AsChunk for &String { + fn source<'a>(&self) -> IoResult> + where + Self: 'a, + { + Ok(Cow::Borrowed(self.as_bytes())) } } -impl<'lua> AsChunk<'lua> for [u8] { - fn source(&self) -> IoResult> { +impl AsChunk for &[u8] { + fn source<'a>(&self) -> IoResult> + where + Self: 'a, + { Ok(Cow::Borrowed(self)) } } -impl<'lua> AsChunk<'lua> for Vec { - fn source(&self) -> IoResult> { +impl AsChunk for Vec { + fn source<'a>(&self) -> IoResult> { + Ok(Cow::Owned(self.clone())) + } +} + +impl AsChunk for &Vec { + fn source<'a>(&self) -> IoResult> + where + Self: 'a, + { Ok(Cow::Borrowed(self)) } } -impl<'lua> AsChunk<'lua> for Path { - fn source(&self) -> IoResult> { +impl AsChunk for &Path { + fn name(&self) -> Option { + Some(format!("@{}", self.display())) + } + + fn source<'a>(&self) -> IoResult> { std::fs::read(self).map(Cow::Owned) } +} - fn name(&self) -> Option { +impl AsChunk for PathBuf { + fn name(&self) -> Option { Some(format!("@{}", self.display())) } -} -impl<'lua> AsChunk<'lua> for PathBuf { - fn source(&self) -> IoResult> { + fn source<'a>(&self) -> IoResult> { std::fs::read(self).map(Cow::Owned) } +} - fn name(&self) -> Option { - Some(format!("@{}", self.display())) +impl AsChunk for Box { + fn name(&self) -> Option { + (**self).name() + } + + fn environment(&self, lua: &Lua) -> Result> { + (**self).environment(lua) + } + + fn mode(&self) -> Option { + (**self).mode() + } + + fn source<'a>(&self) -> IoResult> + where + Self: 'a, + { + (**self).source() } } /// Returned from [`Lua::load`] and is used to finalize loading and executing Lua main chunks. -/// -/// [`Lua::load`]: crate::Lua::load #[must_use = "`Chunk`s do nothing unless one of `exec`, `eval`, `call`, or `into_function` are called on them"] -pub struct Chunk<'lua, 'a> { - pub(crate) lua: &'lua Lua, - pub(crate) source: IoResult>, - pub(crate) name: Option, - pub(crate) env: Result>>, +pub struct Chunk<'a> { + pub(crate) lua: WeakLua, + pub(crate) name: String, + pub(crate) env: Result>, pub(crate) mode: Option, + pub(crate) source: IoResult>, #[cfg(feature = "luau")] pub(crate) compiler: Option, } @@ -105,6 +157,50 @@ pub enum ChunkMode { Binary, } +/// Represents a constant value that can be used by Luau compiler. +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum CompileConstant { + Nil, + Boolean(bool), + Number(crate::Number), + Vector(crate::Vector), + String(String), +} + +#[cfg(any(feature = "luau", doc))] +impl From for CompileConstant { + fn from(b: bool) -> Self { + CompileConstant::Boolean(b) + } +} + +#[cfg(any(feature = "luau", doc))] +impl From for CompileConstant { + fn from(n: crate::Number) -> Self { + CompileConstant::Number(n) + } +} + +#[cfg(any(feature = "luau", doc))] +impl From for CompileConstant { + fn from(v: crate::Vector) -> Self { + CompileConstant::Vector(v) + } +} + +#[cfg(any(feature = "luau", doc))] +impl From<&str> for CompileConstant { + fn from(s: &str) -> Self { + CompileConstant::String(s.to_owned()) + } +} + +#[cfg(any(feature = "luau", doc))] +type LibraryMemberConstantMap = HashMap<(String, String), CompileConstant>; + /// Luau compiler #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] @@ -112,33 +208,45 @@ pub enum ChunkMode { pub struct Compiler { optimization_level: u8, debug_level: u8, + type_info_level: u8, coverage_level: u8, vector_lib: Option, vector_ctor: Option, + vector_type: Option, mutable_globals: Vec, + userdata_types: Vec, + libraries_with_known_members: Vec, + library_constants: Option, + disabled_builtins: Vec, } #[cfg(any(feature = "luau", doc))] impl Default for Compiler { fn default() -> Self { + const { Self::new() } + } +} + +#[cfg(any(feature = "luau", doc))] +impl Compiler { + /// Creates Luau compiler instance with default options + pub const fn new() -> Self { // Defaults are taken from luacode.h Compiler { optimization_level: 1, debug_level: 1, + type_info_level: 0, coverage_level: 0, vector_lib: None, vector_ctor: None, + vector_type: None, mutable_globals: Vec::new(), + userdata_types: Vec::new(), + libraries_with_known_members: Vec::new(), + library_constants: None, + disabled_builtins: Vec::new(), } } -} - -#[cfg(any(feature = "luau", doc))] -impl Compiler { - /// Creates Luau compiler instance with default options - pub fn new() -> Self { - Compiler::default() - } /// Sets Luau compiler optimization level. /// @@ -146,7 +254,8 @@ impl Compiler { /// * 0 - no optimization /// * 1 - baseline optimization level that doesn't prevent debuggability (default) /// * 2 - includes optimizations that harm debuggability such as inlining - pub fn set_optimization_level(mut self, level: u8) -> Self { + #[must_use] + pub const fn set_optimization_level(mut self, level: u8) -> Self { self.optimization_level = level; self } @@ -157,45 +266,139 @@ impl Compiler { /// * 0 - no debugging support /// * 1 - line info & function names only; sufficient for backtraces (default) /// * 2 - full debug info with local & upvalue names; necessary for debugger - pub fn set_debug_level(mut self, level: u8) -> Self { + #[must_use] + pub const fn set_debug_level(mut self, level: u8) -> Self { self.debug_level = level; self } + /// Sets Luau type information level used to guide native code generation decisions. + /// + /// Possible values: + /// * 0 - generate for native modules (default) + /// * 1 - generate for all modules + #[must_use] + pub const fn set_type_info_level(mut self, level: u8) -> Self { + self.type_info_level = level; + self + } + /// Sets Luau compiler code coverage level. /// /// Possible values: /// * 0 - no code coverage support (default) /// * 1 - statement coverage /// * 2 - statement and expression coverage (verbose) - pub fn set_coverage_level(mut self, level: u8) -> Self { + #[must_use] + pub const fn set_coverage_level(mut self, level: u8) -> Self { self.coverage_level = level; self } + /// Sets alternative global builtin to construct vectors, in addition to default builtin + /// `vector.create`. + /// + /// To set the library and method name, use the `lib.ctor` format. #[doc(hidden)] - pub fn set_vector_lib(mut self, lib: Option) -> Self { - self.vector_lib = lib; + #[must_use] + pub fn set_vector_ctor(mut self, ctor: impl Into) -> Self { + let ctor = ctor.into(); + let lib_ctor = ctor.split_once('.'); + self.vector_lib = lib_ctor.as_ref().map(|&(lib, _)| lib.to_owned()); + self.vector_ctor = (lib_ctor.as_ref()) + .map(|&(_, ctor)| ctor.to_owned()) + .or(Some(ctor)); self } + /// Sets alternative vector type name for type tables, in addition to default type `vector`. #[doc(hidden)] - pub fn set_vector_ctor(mut self, ctor: Option) -> Self { - self.vector_ctor = ctor; + #[must_use] + pub fn set_vector_type(mut self, r#type: impl Into) -> Self { + self.vector_type = Some(r#type.into()); + self + } + + /// Adds a mutable global. + /// + /// It disables the import optimization for fields accessed through it. + #[must_use] + pub fn add_mutable_global(mut self, global: impl Into) -> Self { + self.mutable_globals.push(global.into()); self } /// Sets a list of globals that are mutable. /// /// It disables the import optimization for fields accessed through these. - pub fn set_mutable_globals(mut self, globals: Vec) -> Self { - self.mutable_globals = globals; + #[must_use] + pub fn set_mutable_globals>(mut self, globals: impl IntoIterator) -> Self { + self.mutable_globals = globals.into_iter().map(|s| s.into()).collect(); + self + } + + /// Adds a userdata type to the list that will be included in the type information. + #[must_use] + pub fn add_userdata_type(mut self, r#type: impl Into) -> Self { + self.userdata_types.push(r#type.into()); + self + } + + /// Sets a list of userdata types that will be included in the type information. + #[must_use] + pub fn set_userdata_types>(mut self, types: impl IntoIterator) -> Self { + self.userdata_types = types.into_iter().map(|s| s.into()).collect(); + self + } + + /// Adds a constant for a known library member. + /// + /// The constants are used by the compiler to optimize the generated bytecode. + /// Optimization level must be at least 2 for this to have any effect. + /// + /// The `name` is a string in the format `lib.member`, where `lib` is the library name + /// and `member` is the member (constant) name. + #[must_use] + pub fn add_library_constant( + mut self, + name: impl AsRef, + r#const: impl Into, + ) -> Self { + let Some((lib, member)) = name.as_ref().split_once('.') else { + return self; + }; + let (lib, member) = (lib.to_owned(), member.to_owned()); + + if !self.libraries_with_known_members.contains(&lib) { + self.libraries_with_known_members.push(lib.clone()); + } + self.library_constants + .get_or_insert_default() + .insert((lib, member), r#const.into()); + self + } + + /// Adds a builtin that should be disabled. + #[must_use] + pub fn add_disabled_builtin(mut self, builtin: impl Into) -> Self { + self.disabled_builtins.push(builtin.into()); + self + } + + /// Sets a list of builtins that should be disabled. + #[must_use] + pub fn set_disabled_builtins>(mut self, builtins: impl IntoIterator) -> Self { + self.disabled_builtins = builtins.into_iter().map(|s| s.into()).collect(); self } /// Compiles the `source` into bytecode. - pub fn compile(&self, source: impl AsRef<[u8]>) -> Vec { - use std::os::raw::c_int; + /// + /// Returns [`Error::SyntaxError`] if the source code is invalid. + pub fn compile(&self, source: impl AsRef<[u8]>) -> Result> { + use std::cell::RefCell; + use std::ffi::CStr; + use std::os::raw::{c_char, c_int}; use std::ptr; let vector_lib = self.vector_lib.clone(); @@ -204,50 +407,130 @@ impl Compiler { let vector_ctor = self.vector_ctor.clone(); let vector_ctor = vector_ctor.and_then(|ctor| CString::new(ctor).ok()); let vector_ctor = vector_ctor.as_ref(); + let vector_type = self.vector_type.clone(); + let vector_type = vector_type.and_then(|t| CString::new(t).ok()); + let vector_type = vector_type.as_ref(); + + macro_rules! vec2cstring_ptr { + ($name:ident, $name_ptr:ident) => { + let $name = self + .$name + .iter() + .map(|name| CString::new(name.clone()).ok()) + .collect::>>() + .unwrap_or_default(); + let mut $name = $name.iter().map(|s| s.as_ptr()).collect::>(); + let mut $name_ptr = ptr::null(); + if !$name.is_empty() { + $name.push(ptr::null()); + $name_ptr = $name.as_ptr(); + } + }; + } + + vec2cstring_ptr!(mutable_globals, mutable_globals_ptr); + vec2cstring_ptr!(userdata_types, userdata_types_ptr); + vec2cstring_ptr!(libraries_with_known_members, libraries_with_known_members_ptr); + vec2cstring_ptr!(disabled_builtins, disabled_builtins_ptr); - let mutable_globals = self - .mutable_globals - .iter() - .map(|name| CString::new(name.clone()).ok()) - .collect::>>() - .unwrap_or_default(); - let mut mutable_globals = mutable_globals - .iter() - .map(|s| s.as_ptr()) - .collect::>(); - let mut mutable_globals_ptr = ptr::null_mut(); - if !mutable_globals.is_empty() { - mutable_globals.push(ptr::null()); - mutable_globals_ptr = mutable_globals.as_mut_ptr(); + thread_local! { + static LIBRARY_MEMBER_CONSTANT_MAP: RefCell = Default::default(); } - unsafe { - let options = ffi::lua_CompileOptions { - optimizationLevel: self.optimization_level as c_int, - debugLevel: self.debug_level as c_int, - coverageLevel: self.coverage_level as c_int, - vectorLib: vector_lib.map_or(ptr::null(), |s| s.as_ptr()), - vectorCtor: vector_ctor.map_or(ptr::null(), |s| s.as_ptr()), - mutableGlobals: mutable_globals_ptr, - }; + #[cfg(feature = "luau")] + unsafe extern "C-unwind" fn library_member_constant_callback( + library: *const c_char, + member: *const c_char, + constant: *mut ffi::lua_CompileConstant, + ) { + let library = CStr::from_ptr(library).to_string_lossy(); + let member = CStr::from_ptr(member).to_string_lossy(); + LIBRARY_MEMBER_CONSTANT_MAP.with_borrow(|map| { + if let Some(cons) = map.get(&(library.to_string(), member.to_string())) { + match cons { + CompileConstant::Nil => ffi::luau_set_compile_constant_nil(constant), + CompileConstant::Boolean(b) => { + ffi::luau_set_compile_constant_boolean(constant, *b as c_int) + } + CompileConstant::Number(n) => ffi::luau_set_compile_constant_number(constant, *n), + CompileConstant::Vector(v) => { + #[cfg(not(feature = "luau-vector4"))] + ffi::luau_set_compile_constant_vector(constant, v.x(), v.y(), v.z(), 0.0); + #[cfg(feature = "luau-vector4")] + ffi::luau_set_compile_constant_vector(constant, v.x(), v.y(), v.z(), v.w()); + } + CompileConstant::String(s) => ffi::luau_set_compile_constant_string( + constant, + s.as_ptr() as *const c_char, + s.len(), + ), + } + } + }) + } + + let bytecode = unsafe { + let mut options = ffi::lua_CompileOptions::default(); + options.optimizationLevel = self.optimization_level as c_int; + options.debugLevel = self.debug_level as c_int; + options.typeInfoLevel = self.type_info_level as c_int; + options.coverageLevel = self.coverage_level as c_int; + options.vectorLib = vector_lib.map_or(ptr::null(), |s| s.as_ptr()); + options.vectorCtor = vector_ctor.map_or(ptr::null(), |s| s.as_ptr()); + options.vectorType = vector_type.map_or(ptr::null(), |s| s.as_ptr()); + options.mutableGlobals = mutable_globals_ptr; + options.userdataTypes = userdata_types_ptr; + options.librariesWithKnownMembers = libraries_with_known_members_ptr; + if let Some(map) = self.library_constants.as_ref() + && !self.libraries_with_known_members.is_empty() + { + LIBRARY_MEMBER_CONSTANT_MAP.with_borrow_mut(|gmap| *gmap = map.clone()); + options.libraryMemberConstantCallback = Some(library_member_constant_callback); + } + options.disabledBuiltins = disabled_builtins_ptr; ffi::luau_compile(source.as_ref(), options) + }; + + if bytecode.first() == Some(&0) { + // The rest of the bytecode is the error message starting with `:` + // See https://github.com/luau-lang/luau/blob/0.640/Compiler/src/Compiler.cpp#L4336 + let message = String::from_utf8_lossy(&bytecode[2..]).into_owned(); + return Err(Error::SyntaxError { + incomplete_input: message.ends_with(""), + message, + }); } + + Ok(bytecode) } } -impl<'lua, 'a> Chunk<'lua, 'a> { +impl Chunk<'_> { + /// Returns the name of this chunk. + pub fn name(&self) -> &str { + &self.name + } + /// Sets the name of this chunk, which results in more informative error traces. - pub fn set_name(mut self, name: impl AsRef) -> Result { - self.name = Some(name.as_ref().to_string()); - // Do extra validation - let _ = self.convert_name()?; - Ok(self) + /// + /// Possible name prefixes: + /// - `@` - file path (when truncation is needed, the end of the file path is kept, as this is + /// more useful for identifying the file) + /// - `=` - custom chunk name (when truncation is needed, the beginning of the name is kept) + pub fn set_name(mut self, name: impl Into) -> Self { + self.name = name.into(); + self } - /// Sets the first upvalue (`_ENV`) of the loaded chunk to the given value. + /// Returns the environment of this chunk. + pub fn environment(&self) -> Option<&Table> { + self.env.as_ref().ok()?.as_ref() + } + + /// Sets the environment of the loaded chunk to the given value. /// - /// Lua main chunks always have exactly one upvalue, and this upvalue is used as the `_ENV` - /// variable inside the chunk. By default this value is set to the global environment. + /// In Lua >=5.2 main chunks always have exactly one upvalue, and this upvalue is used as the + /// `_ENV` variable inside the chunk. By default this value is set to the global environment. /// /// Calling this method changes the `_ENV` upvalue to the value provided, and variables inside /// the chunk will refer to the given environment rather than the global one. @@ -255,10 +538,14 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// All global variables (including the standard library!) are looked up in `_ENV`, so it may be /// necessary to populate the environment in order for scripts using custom environments to be /// useful. - pub fn set_environment>(mut self, env: V) -> Result { - // Prefer to propagate errors here and wrap to `Ok` - self.env = Ok(Some(env.to_lua(self.lua)?)); - Ok(self) + pub fn set_environment(mut self, env: Table) -> Self { + self.env = Ok(Some(env)); + self + } + + /// Returns the mode (auto-detected by default) of this chunk. + pub fn mode(&self) -> ChunkMode { + self.detect_mode() } /// Sets whether the chunk is text or binary (autodetected by default). @@ -273,8 +560,6 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// Sets or overwrites a Luau compiler used for this chunk. /// /// See [`Compiler`] for details and possible options. - /// - /// Requires `feature = "luau"` #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_compiler(mut self, compiler: Compiler) -> Self { @@ -286,24 +571,18 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// /// This is equivalent to calling the chunk function with no arguments and no return values. pub fn exec(self) -> Result<()> { - self.call(())?; - Ok(()) + self.call(()) } /// Asynchronously execute this chunk of code. /// /// See [`exec`] for more details. /// - /// Requires `feature = "async"` - /// - /// [`exec`]: #method.exec + /// [`exec`]: Chunk::exec #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn exec_async<'fut>(self) -> LocalBoxFuture<'fut, Result<()>> - where - 'lua: 'fut, - { - self.call_async(()) + pub async fn exec_async(self) -> Result<()> { + self.call_async(()).await } /// Evaluate the chunk as either an expression or block. @@ -311,7 +590,7 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// If the chunk can be parsed as an expression, this loads and executes the chunk and returns /// the value that it evaluates to. Otherwise, the chunk is interpreted as a block as normal, /// and this is equivalent to calling `exec`. - pub fn eval>(self) -> Result { + pub fn eval(self) -> Result { // Bytecode is always interpreted as a statement. // For source code, first try interpreting the lua as an expression by adding // "return", then as a statement. This is the same thing the @@ -329,29 +608,26 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// /// See [`eval`] for more details. /// - /// Requires `feature = "async"` - /// - /// [`eval`]: #method.eval + /// [`eval`]: Chunk::eval #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn eval_async<'fut, R>(self) -> LocalBoxFuture<'fut, Result> + pub async fn eval_async(self) -> Result where - 'lua: 'fut, - R: FromLuaMulti<'lua> + 'fut, + R: FromLuaMulti, { if self.detect_mode() == ChunkMode::Binary { - self.call_async(()) + self.call_async(()).await } else if let Ok(function) = self.to_expression() { - function.call_async(()) + function.call_async(()).await } else { - self.call_async(()) + self.call_async(()).await } } /// Load the chunk function and call it with the given arguments. /// /// This is equivalent to `into_function` and calling the resulting function. - pub fn call, R: FromLuaMulti<'lua>>(self, args: A) -> Result { + pub fn call(self, args: impl IntoLuaMulti) -> Result { self.into_function()?.call(args) } @@ -359,60 +635,50 @@ impl<'lua, 'a> Chunk<'lua, 'a> { /// /// See [`call`] for more details. /// - /// Requires `feature = "async"` - /// - /// [`call`]: #method.call + /// [`call`]: Chunk::call #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn call_async<'fut, A, R>(self, args: A) -> LocalBoxFuture<'fut, Result> + pub async fn call_async(self, args: impl IntoLuaMulti) -> Result where - 'lua: 'fut, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut, + R: FromLuaMulti, { - match self.into_function() { - Ok(func) => func.call_async(args), - Err(e) => Box::pin(future::err(e)), - } + self.into_function()?.call_async(args).await } - /// Load this chunk into a regular `Function`. + /// Load this chunk into a regular [`Function`]. /// /// This simply compiles the chunk without actually executing it. #[cfg_attr(not(feature = "luau"), allow(unused_mut))] - pub fn into_function(mut self) -> Result> { + pub fn into_function(mut self) -> Result { #[cfg(feature = "luau")] if self.compiler.is_some() { // We don't need to compile source if no compiler set self.compile(); } - let name = self.convert_name()?; + let name = Self::convert_name(self.name)?; self.lua - .load_chunk(self.source?.as_ref(), name.as_deref(), self.env?, self.mode) + .lock() + .load_chunk(Some(&name), self.env?.as_ref(), self.mode, self.source?.as_ref()) } /// Compiles the chunk and changes mode to binary. /// - /// It does nothing if the chunk is already binary. + /// It does nothing if the chunk is already binary or invalid. fn compile(&mut self) { - if let Ok(ref source) = self.source { - if self.detect_mode() == ChunkMode::Text { - #[cfg(feature = "luau")] - { - let data = self - .compiler - .get_or_insert_with(Default::default) - .compile(source); - self.source = Ok(Cow::Owned(data)); - self.mode = Some(ChunkMode::Binary); - } - #[cfg(not(feature = "luau"))] - if let Ok(func) = self.lua.load_chunk(source.as_ref(), None, None, None) { - let data = func.dump(false); - self.source = Ok(Cow::Owned(data)); - self.mode = Some(ChunkMode::Binary); - } + if let Ok(ref source) = self.source + && self.detect_mode() == ChunkMode::Text + { + #[cfg(feature = "luau")] + if let Ok(data) = self.compiler.get_or_insert_default().compile(source) { + self.source = Ok(Cow::Owned(data)); + self.mode = Some(ChunkMode::Binary); + } + #[cfg(not(feature = "luau"))] + if let Ok(func) = self.lua.lock().load_chunk(None, None, None, source.as_ref()) { + let data = func.dump(false); + self.source = Ok(Cow::Owned(data)); + self.mode = Some(ChunkMode::Binary); } } } @@ -425,31 +691,33 @@ impl<'lua, 'a> Chunk<'lua, 'a> { // Try to fetch compiled chunk from cache let mut text_source = None; - if let Ok(ref source) = self.source { - if self.detect_mode() == ChunkMode::Text { - if let Some(cache) = self.lua.app_data_ref::() { - if let Some(data) = cache.0.get(source.as_ref()) { - self.source = Ok(Cow::Owned(data.clone())); - self.mode = Some(ChunkMode::Binary); - return self; - } - } - text_source = Some(source.as_ref().to_vec()); + if let Ok(ref source) = self.source + && self.detect_mode() == ChunkMode::Text + { + let lua = self.lua.lock(); + if let Some(cache) = lua.priv_app_data_ref::() + && let Some(data) = cache.0.get(source.as_ref()) + { + self.source = Ok(Cow::Owned(data.clone())); + self.mode = Some(ChunkMode::Binary); + return self; } + text_source = Some(source.as_ref().to_vec()); } // Compile and cache the chunk if let Some(text_source) = text_source { self.compile(); - if let Ok(ref binary_source) = self.source { - if self.detect_mode() == ChunkMode::Binary { - if let Some(mut cache) = self.lua.app_data_mut::() { - cache.0.insert(text_source, binary_source.as_ref().to_vec()); - } else { - let mut cache = ChunksCache(HashMap::new()); - cache.0.insert(text_source, binary_source.as_ref().to_vec()); - self.lua.set_app_data(cache); - } + if let Ok(ref binary_source) = self.source + && self.detect_mode() == ChunkMode::Binary + { + let lua = self.lua.lock(); + if let Some(mut cache) = lua.priv_app_data_mut::() { + cache.0.insert(text_source, binary_source.to_vec()); + } else { + let mut cache = ChunksCache(HashMap::new()); + cache.0.insert(text_source, binary_source.to_vec()); + lua.set_priv_app_data(cache); } } } @@ -457,10 +725,10 @@ impl<'lua, 'a> Chunk<'lua, 'a> { self } - fn to_expression(&self) -> Result> { + fn to_expression(&self) -> Result { // We assume that mode is Text let source = self.source.as_ref(); - let source = source.map_err(|err| Error::RuntimeError(err.to_string()))?; + let source = source.map_err(Error::runtime)?; let source = Self::expression_source(source); // We don't need to compile source if no compiler options set #[cfg(feature = "luau")] @@ -468,37 +736,37 @@ impl<'lua, 'a> Chunk<'lua, 'a> { .compiler .as_ref() .map(|c| c.compile(&source)) + .transpose()? .unwrap_or(source); - let name = self.convert_name()?; - self.lua - .load_chunk(&source, name.as_deref(), self.env.clone()?, None) + let name = Self::convert_name(self.name.clone())?; + let env = match &self.env { + Ok(Some(env)) => Some(env), + Ok(None) => None, + Err(err) => return Err(err.clone()), + }; + self.lua.lock().load_chunk(Some(&name), env, None, &source) } fn detect_mode(&self) -> ChunkMode { - match (self.mode, &self.source) { - (Some(mode), _) => mode, - (None, Ok(source)) => { - #[cfg(not(feature = "luau"))] - if source.starts_with(ffi::LUA_SIGNATURE) { - return ChunkMode::Binary; - } - #[cfg(feature = "luau")] - if *source.get(0).unwrap_or(&u8::MAX) < b'\n' { - return ChunkMode::Binary; - } - ChunkMode::Text + if let Some(mode) = self.mode { + return mode; + } + if let Ok(source) = &self.source { + #[cfg(not(feature = "luau"))] + if source.starts_with(ffi::LUA_SIGNATURE) { + return ChunkMode::Binary; + } + #[cfg(feature = "luau")] + if *source.first().unwrap_or(&u8::MAX) < b'\n' { + return ChunkMode::Binary; } - (None, Err(_)) => ChunkMode::Text, // any value is fine } + ChunkMode::Text } - fn convert_name(&self) -> Result> { - self.name - .clone() - .map(CString::new) - .transpose() - .map_err(|err| Error::RuntimeError(format!("invalid name: {err}"))) + fn convert_name(name: String) -> Result { + CString::new(name).map_err(|err| Error::runtime(format!("invalid name: {err}"))) } fn expression_source(source: &[u8]) -> Vec { @@ -508,3 +776,30 @@ impl<'lua, 'a> Chunk<'lua, 'a> { buf } } + +struct WrappedChunk { + chunk: T, + caller: &'static Location<'static>, +} + +impl Chunk<'_> { + /// Wraps a chunk of Lua code, returning an opaque type that implements [`IntoLua`] trait. + /// + /// The resulted `IntoLua` implementation will convert the chunk into a Lua function without + /// executing it. + #[track_caller] + pub fn wrap(chunk: impl AsChunk) -> impl IntoLua { + WrappedChunk { + chunk, + caller: Location::caller(), + } + } +} + +impl IntoLua for WrappedChunk { + fn into_lua(self, lua: &Lua) -> Result { + lua.load_with_location(self.chunk, self.caller) + .into_function() + .map(Value::Function) + } +} diff --git a/src/conversion.rs b/src/conversion.rs index b9939c1c..1fdacb03 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -1,275 +1,550 @@ -#![allow(clippy::wrong_self_convention)] - use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::convert::TryInto; -use std::ffi::{CStr, CString}; +use std::ffi::{CStr, CString, OsStr, OsString}; use std::hash::{BuildHasher, Hash}; -use std::string::String as StdString; +use std::os::raw::c_int; +use std::path::{Path, PathBuf}; +use std::{slice, str}; -use bstr::{BStr, BString}; +use bstr::{BStr, BString, ByteVec}; use num_traits::cast; use crate::error::{Error, Result}; use crate::function::Function; -use crate::lua::Lua; -use crate::string::String; +use crate::state::{Lua, RawLua}; +use crate::string::{BorrowedBytes, BorrowedStr, LuaString}; use crate::table::Table; use crate::thread::Thread; -use crate::types::{LightUserData, MaybeSend}; +use crate::traits::{FromLua, IntoLua, ShortTypeName as _}; +use crate::types::{Either, LightUserData, MaybeSend, MaybeSync, RegistryKey}; use crate::userdata::{AnyUserData, UserData}; -use crate::value::{FromLua, Nil, ToLua, Value}; +use crate::value::{Nil, Value}; -impl<'lua> ToLua<'lua> for Value<'lua> { +impl IntoLua for Value { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(self) } } -impl<'lua> FromLua<'lua> for Value<'lua> { +impl IntoLua for &Value { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(self.clone()) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_value(self) + } +} + +impl FromLua for Value { #[inline] - fn from_lua(lua_value: Value<'lua>, _: &'lua Lua) -> Result { + fn from_lua(lua_value: Value, _: &Lua) -> Result { Ok(lua_value) } } -impl<'lua> ToLua<'lua> for String<'lua> { +impl IntoLua for LuaString { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::String(self)) } } -impl<'lua> FromLua<'lua> for String<'lua> { +impl IntoLua for &LuaString { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +impl FromLua for LuaString { #[inline] - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result> { + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); lua.coerce_string(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: "String", - message: Some("expected string or number".to_string()), - }) + .ok_or_else(|| Error::from_lua_conversion(ty, "string", "expected string or number".to_string())) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + let type_id = ffi::lua_type(state, idx); + if type_id == ffi::LUA_TSTRING { + ffi::lua_xpush(state, lua.ref_thread(), idx); + return Ok(LuaString(lua.pop_ref_thread())); + } + // Fallback to default + Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) + } +} + +impl IntoLua for BorrowedStr { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(LuaString(self.vref))) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.vref); + Ok(()) + } +} + +impl IntoLua for &BorrowedStr { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(LuaString(self.vref.clone()))) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.vref); + Ok(()) + } +} + +impl FromLua for BorrowedStr { + fn from_lua(value: Value, lua: &Lua) -> Result { + let s = LuaString::from_lua(value, lua)?; + BorrowedStr::try_from(&s) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let s = LuaString::from_stack(idx, lua)?; + BorrowedStr::try_from(&s) + } +} + +impl IntoLua for BorrowedBytes { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(LuaString(self.vref))) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.vref); + Ok(()) + } +} + +impl IntoLua for &BorrowedBytes { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(LuaString(self.vref.clone()))) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.vref); + Ok(()) + } +} + +impl FromLua for BorrowedBytes { + fn from_lua(value: Value, lua: &Lua) -> Result { + let s = LuaString::from_lua(value, lua)?; + Ok(BorrowedBytes::from(&s)) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let s = LuaString::from_stack(idx, lua)?; + Ok(BorrowedBytes::from(&s)) } } -impl<'lua> ToLua<'lua> for Table<'lua> { +impl IntoLua for Table { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::Table(self)) } } -impl<'lua> FromLua<'lua> for Table<'lua> { +impl IntoLua for &Table { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Table(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +impl FromLua for Table { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { + fn from_lua(value: Value, _: &Lua) -> Result
{ match value { Value::Table(table) => Ok(table), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "table", - message: None, - }), + _ => Err(Error::from_lua_conversion(value.type_name(), "table", None)), } } } -impl<'lua> ToLua<'lua> for Function<'lua> { +impl IntoLua for Function { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::Function(self)) } } -impl<'lua> FromLua<'lua> for Function<'lua> { +impl IntoLua for &Function { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Function(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +impl FromLua for Function { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { + fn from_lua(value: Value, _: &Lua) -> Result { match value { Value::Function(table) => Ok(table), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "function", - message: None, - }), + _ => Err(Error::from_lua_conversion(value.type_name(), "function", None)), } } } -impl<'lua> ToLua<'lua> for Thread<'lua> { +impl IntoLua for Thread { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::Thread(self)) } } -impl<'lua> FromLua<'lua> for Thread<'lua> { +impl IntoLua for &Thread { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Thread(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +impl FromLua for Thread { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { + fn from_lua(value: Value, _: &Lua) -> Result { match value { Value::Thread(t) => Ok(t), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "thread", - message: None, - }), + _ => Err(Error::from_lua_conversion(value.type_name(), "thread", None)), } } } -impl<'lua> ToLua<'lua> for AnyUserData<'lua> { +impl IntoLua for AnyUserData { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::UserData(self)) } } -impl<'lua> FromLua<'lua> for AnyUserData<'lua> { +impl IntoLua for &AnyUserData { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::UserData(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +impl FromLua for AnyUserData { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { match value { Value::UserData(ud) => Ok(ud), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "userdata", - message: None, - }), + _ => Err(Error::from_lua_conversion(value.type_name(), "userdata", None)), } } } -impl<'lua, T: 'static + MaybeSend + UserData> ToLua<'lua> for T { +impl IntoLua for T { #[inline] - fn to_lua(self, lua: &'lua Lua) -> Result> { + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::UserData(lua.create_userdata(self)?)) } } -impl<'lua, T: 'static + UserData + Clone> FromLua<'lua> for T { +impl IntoLua for Error { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Error(Box::new(self))) + } +} + +impl FromLua for Error { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + fn from_lua(value: Value, _: &Lua) -> Result { match value { - Value::UserData(ud) => Ok(ud.borrow::()?.clone()), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "userdata", - message: None, - }), + Value::Error(err) => Ok(*err), + val => Ok(Error::runtime(val.to_string()?)), } } } -impl<'lua> ToLua<'lua> for Error { +#[cfg(feature = "anyhow")] +impl IntoLua for anyhow::Error { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { - Ok(Value::Error(self)) + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Error(Box::new(Error::from(self)))) } } -impl<'lua> FromLua<'lua> for Error { +impl IntoLua for RegistryKey { #[inline] - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { - match value { - Value::Error(err) => Ok(err), - val => Ok(Error::RuntimeError( - lua.coerce_string(val)? - .and_then(|s| Some(s.to_str().ok()?.to_owned())) - .unwrap_or_else(|| "".to_owned()), - )), + fn into_lua(self, lua: &Lua) -> Result { + lua.registry_value(&self) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + <&RegistryKey>::push_into_stack(&self, lua) + } +} + +impl IntoLua for &RegistryKey { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + lua.registry_value(self) + } + + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + if !lua.owns_registry_value(self) { + return Err(Error::MismatchedRegistryKey); + } + + match self.id() { + ffi::LUA_REFNIL => ffi::lua_pushnil(lua.state()), + id => { + ffi::lua_rawgeti(lua.state(), ffi::LUA_REGISTRYINDEX, id as _); + } } + Ok(()) + } +} + +impl FromLua for RegistryKey { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { + lua.create_registry_value(value) } } -impl<'lua> ToLua<'lua> for bool { +impl IntoLua for bool { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::Boolean(self)) } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + ffi::lua_pushboolean(lua.state(), self as c_int); + Ok(()) + } } -impl<'lua> FromLua<'lua> for bool { +impl FromLua for bool { #[inline] - fn from_lua(v: Value<'lua>, _: &'lua Lua) -> Result { + fn from_lua(v: Value, _: &Lua) -> Result { match v { Value::Nil => Ok(false), Value::Boolean(b) => Ok(b), _ => Ok(true), } } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Ok(ffi::lua_toboolean(lua.state(), idx) != 0) + } } -impl<'lua> ToLua<'lua> for LightUserData { +impl IntoLua for LightUserData { #[inline] - fn to_lua(self, _: &'lua Lua) -> Result> { + fn into_lua(self, _: &Lua) -> Result { Ok(Value::LightUserData(self)) } } -impl<'lua> FromLua<'lua> for LightUserData { +impl FromLua for LightUserData { #[inline] - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { + fn from_lua(value: Value, _: &Lua) -> Result { match value { Value::LightUserData(ud) => Ok(ud), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "light userdata", - message: None, - }), + _ => Err(Error::from_lua_conversion( + value.type_name(), + "lightuserdata", + None, + )), } } } -impl<'lua> ToLua<'lua> for StdString { +#[cfg(feature = "luau")] +impl IntoLua for crate::Vector { #[inline] - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::String(lua.create_string(&self)?)) + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Vector(self)) } } -impl<'lua> FromLua<'lua> for StdString { +#[cfg(feature = "luau")] +impl FromLua for crate::Vector { #[inline] - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { + fn from_lua(value: Value, _: &Lua) -> Result { + match value { + Value::Vector(v) => Ok(v), + _ => Err(Error::from_lua_conversion(value.type_name(), "vector", None)), + } + } +} + +#[cfg(feature = "luau")] +impl IntoLua for crate::Buffer { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Buffer(self)) + } +} + +#[cfg(feature = "luau")] +impl IntoLua for &crate::Buffer { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Buffer(self.clone())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.0); + Ok(()) + } +} + +#[cfg(feature = "luau")] +impl FromLua for crate::Buffer { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { + match value { + Value::Buffer(buf) => Ok(buf), + _ => Err(Error::from_lua_conversion(value.type_name(), "buffer", None)), + } + } +} + +impl IntoLua for String { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + #[cfg(feature = "lua55")] + if true { + return Ok(Value::String(lua.create_external_string(self)?)); + } + + Ok(Value::String(lua.create_string(self)?)) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + #[cfg(feature = "lua55")] + if lua.unlikely_memory_error() { + return crate::util::push_external_string(lua.state(), self.into(), false); + } + + push_bytes_into_stack(self, lua) + } +} + +impl FromLua for String { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); Ok(lua .coerce_string(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: "String", - message: Some("expected string or number".to_string()), + .ok_or_else(|| { + Error::from_lua_conversion(ty, Self::type_name(), "expected string or number".to_string()) })? .to_str()? .to_owned()) } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + let type_id = ffi::lua_type(state, idx); + if type_id == ffi::LUA_TSTRING { + let mut size = 0; + let data = ffi::lua_tolstring(state, idx, &mut size); + let bytes = slice::from_raw_parts(data as *const u8, size); + return str::from_utf8(bytes) + .map(|s| s.to_owned()) + .map_err(|e| Error::from_lua_conversion("string", Self::type_name(), e.to_string())); + } + // Fallback to default + Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) + } } -impl<'lua> ToLua<'lua> for &str { +impl IntoLua for &str { #[inline] - fn to_lua(self, lua: &'lua Lua) -> Result> { + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::String(lua.create_string(self)?)) } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + push_bytes_into_stack(self, lua) + } } -impl<'lua> ToLua<'lua> for Cow<'_, str> { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::String(lua.create_string(self.as_bytes())?)) +impl IntoLua for Cow<'_, str> { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + match self { + Cow::Borrowed(s) => s.into_lua(lua), + Cow::Owned(s) => s.into_lua(lua), + } } } -impl<'lua> ToLua<'lua> for Box { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for Box { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::String(lua.create_string(&*self)?)) } } -impl<'lua> FromLua<'lua> for Box { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { +impl FromLua for Box { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); Ok(lua .coerce_string(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: "Box", - message: Some("expected string or number".to_string()), + .ok_or_else(|| { + Error::from_lua_conversion(ty, Self::type_name(), "expected string or number".to_string()) })? .to_str()? .to_owned() @@ -277,113 +552,275 @@ impl<'lua> FromLua<'lua> for Box { } } -impl<'lua> ToLua<'lua> for CString { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for CString { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + #[cfg(feature = "lua55")] + if true { + return Ok(Value::String(lua.create_external_string(self)?)); + } + Ok(Value::String(lua.create_string(self.as_bytes())?)) } } -impl<'lua> FromLua<'lua> for CString { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { +impl FromLua for CString { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); - let string = lua - .coerce_string(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: "CString", - message: Some("expected string or number".to_string()), - })?; - - match CStr::from_bytes_with_nul(string.as_bytes_with_nul()) { + let string = lua.coerce_string(value)?.ok_or_else(|| { + Error::from_lua_conversion(ty, Self::type_name(), "expected string or number".to_string()) + })?; + match CStr::from_bytes_with_nul(&string.as_bytes_with_nul()) { Ok(s) => Ok(s.into()), - Err(_) => Err(Error::FromLuaConversionError { - from: ty, - to: "CString", - message: Some("invalid C-style string".to_string()), - }), + Err(err) => Err(Error::from_lua_conversion(ty, Self::type_name(), err.to_string())), } } } -impl<'lua> ToLua<'lua> for &CStr { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for &CStr { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::String(lua.create_string(self.to_bytes())?)) } } -impl<'lua> ToLua<'lua> for Cow<'_, CStr> { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::String(lua.create_string(self.to_bytes())?)) +impl IntoLua for Cow<'_, CStr> { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + match self { + Cow::Borrowed(s) => s.into_lua(lua), + Cow::Owned(s) => s.into_lua(lua), + } } } -impl<'lua> ToLua<'lua> for BString { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::String(lua.create_string(&self)?)) +impl IntoLua for BString { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + #[cfg(feature = "lua55")] + if true { + return Ok(Value::String(lua.create_external_string(self)?)); + } + + Ok(Value::String(lua.create_string(self)?)) } } -impl<'lua> FromLua<'lua> for BString { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { +impl FromLua for BString { + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); - Ok(BString::from( - lua.coerce_string(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: "String", - message: Some("expected string or number".to_string()), + match value { + Value::String(s) => Ok((*s.as_bytes()).into()), + #[cfg(feature = "luau")] + Value::Buffer(buf) => Ok(buf.to_vec().into()), + _ => Ok((*lua + .coerce_string(value)? + .ok_or_else(|| { + Error::from_lua_conversion(ty, Self::type_name(), "expected string or number".to_string()) })? - .as_bytes() - .to_vec(), - )) + .as_bytes()) + .into()), + } + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + match ffi::lua_type(state, idx) { + ffi::LUA_TSTRING => { + let mut size = 0; + let data = ffi::lua_tolstring(state, idx, &mut size); + Ok(slice::from_raw_parts(data as *const u8, size).into()) + } + #[cfg(feature = "luau")] + ffi::LUA_TBUFFER => { + let mut size = 0; + let buf = ffi::lua_tobuffer(state, idx, &mut size); + mlua_assert!(!buf.is_null(), "invalid Luau buffer"); + Ok(slice::from_raw_parts(buf as *const u8, size).into()) + } + type_id => { + // Fallback to default + Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) + } + } + } +} + +impl IntoLua for &BStr { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + Ok(Value::String(lua.create_string(self)?)) } } -impl<'lua> ToLua<'lua> for &BStr { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::String(lua.create_string(&self)?)) +impl IntoLua for OsString { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + self.as_os_str().into_lua(lua) + } +} + +impl FromLua for OsString { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { + let ty = value.type_name(); + let bs = BString::from_lua(value, lua)?; + Vec::from(bs) + .into_os_string() + .map_err(|err| Error::from_lua_conversion(ty, "OsString", err.to_string())) + } +} + +impl IntoLua for &OsStr { + #[cfg(unix)] + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + use std::os::unix::ffi::OsStrExt; + Ok(Value::String(lua.create_string(self.as_bytes())?)) + } + + #[cfg(not(unix))] + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + self.display().to_string().into_lua(lua) + } +} + +impl IntoLua for PathBuf { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + self.as_os_str().into_lua(lua) + } +} + +impl FromLua for PathBuf { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { + OsString::from_lua(value, lua).map(PathBuf::from) + } +} + +impl IntoLua for &Path { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + self.as_os_str().into_lua(lua) + } +} + +impl IntoLua for char { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + let mut char_bytes = [0; 4]; + self.encode_utf8(&mut char_bytes); + Ok(Value::String(lua.create_string(&char_bytes[..self.len_utf8()])?)) } } +impl FromLua for char { + fn from_lua(value: Value, _lua: &Lua) -> Result { + let ty = value.type_name(); + match value { + Value::Integer(i) => cast(i).and_then(char::from_u32).ok_or_else(|| { + let msg = "integer out of range when converting to char"; + Error::from_lua_conversion(ty, "char", msg.to_string()) + }), + Value::String(s) => { + let str = s.to_str()?; + let mut str_iter = str.chars(); + match (str_iter.next(), str_iter.next()) { + (Some(char), None) => Ok(char), + _ => { + let msg = "expected string to have exactly one char when converting to char"; + Err(Error::from_lua_conversion(ty, "char", msg.to_string())) + } + } + } + _ => { + let msg = "expected string or integer"; + Err(Error::from_lua_conversion(ty, Self::type_name(), msg.to_string())) + } + } + } +} + +#[inline] +unsafe fn push_bytes_into_stack(this: T, lua: &RawLua) -> Result<()> +where + T: IntoLua + AsRef<[u8]>, +{ + let bytes = this.as_ref(); + if lua.unlikely_memory_error() && bytes.len() < (1 << 30) { + // Fast path: push directly into the Lua stack. + ffi::lua_pushlstring(lua.state(), bytes.as_ptr() as *const _, bytes.len()); + return Ok(()); + } + // Fallback to default + lua.push_value(&T::into_lua(this, lua.lua())?) +} + macro_rules! lua_convert_int { ($x:ty) => { - impl<'lua> ToLua<'lua> for $x { - fn to_lua(self, _: &'lua Lua) -> Result> { - cast(self) + impl IntoLua for $x { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(cast(self) .map(Value::Integer) - .or_else(|| cast(self).map(Value::Number)) - // This is impossible error because conversion to Number never fails - .ok_or_else(|| Error::ToLuaConversionError { - from: stringify!($x), - to: "number", - message: Some("out of range".to_owned()), - }) + .unwrap_or_else(|| Value::Number(self as ffi::lua_Number))) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + match cast(self) { + Some(i) => ffi::lua_pushinteger(lua.state(), i), + None => ffi::lua_pushnumber(lua.state(), self as ffi::lua_Number), + } + Ok(()) } } - impl<'lua> FromLua<'lua> for $x { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { + impl FromLua for $x { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); - (if let Value::Integer(i) = value { - cast(i) - } else if let Some(i) = lua.coerce_integer(value.clone())? { - cast(i) - } else { - cast(lua.coerce_number(value)?.ok_or_else(|| { - Error::FromLuaConversionError { - from: ty, - to: stringify!($x), - message: Some( - "expected number or string coercible to number".to_string(), - ), + (match value { + Value::Integer(i) => cast(i), + Value::Number(n) => cast(n), + _ => { + if let Some(i) = lua.coerce_integer(value.clone())? { + cast(i) + } else { + cast(lua.coerce_number(value)?.ok_or_else(|| { + let msg = "expected number or string coercible to number"; + Error::from_lua_conversion(ty, stringify!($x), msg.to_string()) + })?) } - })?) - }) - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: stringify!($x), - message: Some("out of range".to_owned()), + } }) + .ok_or_else(|| Error::from_lua_conversion(ty, stringify!($x), "out of range".to_string())) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + let type_id = ffi::lua_type(state, idx); + if type_id == ffi::LUA_TNUMBER { + let mut ok = 0; + let i = ffi::lua_tointegerx(state, idx, &mut ok); + if ok != 0 { + return cast(i).ok_or_else(|| { + Error::from_lua_conversion("integer", stringify!($x), "out of range".to_string()) + }); + } + } + #[cfg(feature = "luau")] + if type_id == ffi::LUA_TINTEGER { + let i = ffi::lua_tointeger64(state, idx, std::ptr::null_mut()); + return cast(i).ok_or_else(|| { + Error::from_lua_conversion("integer", stringify!($x), "out of range".to_string()) + }); + } + // Fallback to default + Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) } } }; @@ -404,34 +841,31 @@ lua_convert_int!(usize); macro_rules! lua_convert_float { ($x:ty) => { - impl<'lua> ToLua<'lua> for $x { - fn to_lua(self, _: &'lua Lua) -> Result> { - cast(self) - .ok_or_else(|| Error::ToLuaConversionError { - from: stringify!($x), - to: "number", - message: Some("out of range".to_string()), - }) - .map(Value::Number) + impl IntoLua for $x { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::Number(self as _)) } } - impl<'lua> FromLua<'lua> for $x { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { + impl FromLua for $x { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { let ty = value.type_name(); - lua.coerce_number(value)? - .ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: stringify!($x), - message: Some("expected number or string coercible to number".to_string()), - }) - .and_then(|n| { - cast(n).ok_or_else(|| Error::FromLuaConversionError { - from: ty, - to: stringify!($x), - message: Some("number out of range".to_string()), - }) - }) + lua.coerce_number(value)?.map(|n| n as $x).ok_or_else(|| { + let msg = "expected number or string coercible to number"; + Error::from_lua_conversion(ty, stringify!($x), msg.to_string()) + }) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + let type_id = ffi::lua_type(state, idx); + if type_id == ffi::LUA_TNUMBER { + return Ok(ffi::lua_tonumber(state, idx) as _); + } + // Fallback to default + Self::from_lua(lua.stack_value(idx, Some(type_id)), lua.lua()) } } }; @@ -440,208 +874,275 @@ macro_rules! lua_convert_float { lua_convert_float!(f32); lua_convert_float!(f64); -impl<'lua, T> ToLua<'lua> for &[T] +impl IntoLua for &[T] where - T: Clone + ToLua<'lua>, + T: IntoLua + Clone, { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::Table( - lua.create_sequence_from(self.iter().cloned())?, - )) + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + Ok(Value::Table(lua.create_sequence_from(self.iter().cloned())?)) } } -impl<'lua, T, const N: usize> ToLua<'lua> for [T; N] +impl IntoLua for [T; N] where - T: ToLua<'lua>, + T: IntoLua, { - fn to_lua(self, lua: &'lua Lua) -> Result> { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::Table(lua.create_sequence_from(self)?)) } } -impl<'lua, T, const N: usize> FromLua<'lua> for [T; N] +impl FromLua for [T; N] where - T: FromLua<'lua>, + T: FromLua, { - fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> Result { + #[inline] + fn from_lua(value: Value, _lua: &Lua) -> Result { match value { #[cfg(feature = "luau")] - Value::Vector(x, y, z) if N == 3 => Ok(mlua_expect!( - vec![ - T::from_lua(Value::Number(x as _), _lua)?, - T::from_lua(Value::Number(y as _), _lua)?, - T::from_lua(Value::Number(z as _), _lua)?, - ] - .try_into() - .map_err(|_| ()), - "cannot convert vector to array" - )), + #[rustfmt::skip] + Value::Vector(v) if N == crate::Vector::SIZE => unsafe { + use std::{mem, ptr}; + let mut arr: [mem::MaybeUninit; N] = mem::MaybeUninit::uninit().assume_init(); + ptr::write(arr[0].as_mut_ptr() , T::from_lua(Value::Number(v.x() as _), _lua)?); + ptr::write(arr[1].as_mut_ptr(), T::from_lua(Value::Number(v.y() as _), _lua)?); + ptr::write(arr[2].as_mut_ptr(), T::from_lua(Value::Number(v.z() as _), _lua)?); + #[cfg(feature = "luau-vector4")] + ptr::write(arr[3].as_mut_ptr(), T::from_lua(Value::Number(v.w() as _), _lua)?); + Ok(mem::transmute_copy(&arr)) + }, Value::Table(table) => { let vec = table.sequence_values().collect::>>()?; - vec.try_into() - .map_err(|vec: Vec| Error::FromLuaConversionError { - from: "Table", - to: "Array", - message: Some(format!("expected table of length {}, got {}", N, vec.len())), - }) + vec.try_into().map_err(|vec: Vec| { + let msg = format!("expected table of length {N}, got {}", vec.len()); + Error::from_lua_conversion("table", Self::type_name(), msg) + }) + } + _ => { + let msg = format!("expected table of length {N}"); + let err = Error::from_lua_conversion(value.type_name(), Self::type_name(), msg.to_string()); + Err(err) } - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "Array", - message: Some("expected table".to_string()), - }), } } } -impl<'lua, T: ToLua<'lua>> ToLua<'lua> for Box<[T]> { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for Box<[T]> { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::Table(lua.create_sequence_from(self.into_vec())?)) } } -impl<'lua, T: FromLua<'lua>> FromLua<'lua> for Box<[T]> { - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { +impl FromLua for Box<[T]> { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { Ok(Vec::::from_lua(value, lua)?.into_boxed_slice()) } } -impl<'lua, T: ToLua<'lua>> ToLua<'lua> for Vec { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for Vec { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::Table(lua.create_sequence_from(self)?)) } } -impl<'lua, T: FromLua<'lua>> FromLua<'lua> for Vec { - fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> Result { +impl FromLua for Vec { + #[inline] + fn from_lua(value: Value, _lua: &Lua) -> Result { match value { - #[cfg(feature = "luau")] - Value::Vector(x, y, z) => Ok(vec![ - T::from_lua(Value::Number(x as _), _lua)?, - T::from_lua(Value::Number(y as _), _lua)?, - T::from_lua(Value::Number(z as _), _lua)?, - ]), Value::Table(table) => table.sequence_values().collect(), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "Vec", - message: Some("expected table".to_string()), - }), + _ => Err(Error::from_lua_conversion( + value.type_name(), + Self::type_name(), + "expected table".to_string(), + )), } } } -impl<'lua, K: Eq + Hash + ToLua<'lua>, V: ToLua<'lua>, S: BuildHasher> ToLua<'lua> - for HashMap -{ - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for HashMap { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::Table(lua.create_table_from(self)?)) } } -impl<'lua, K: Eq + Hash + FromLua<'lua>, V: FromLua<'lua>, S: BuildHasher + Default> FromLua<'lua> - for HashMap -{ - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { - if let Value::Table(table) = value { - table.pairs().collect() - } else { - Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "HashMap", - message: Some("expected table".to_string()), - }) +impl FromLua for HashMap { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { + match value { + Value::Table(table) => table.pairs().collect(), + _ => Err(Error::from_lua_conversion( + value.type_name(), + Self::type_name(), + "expected table".to_string(), + )), } } } -impl<'lua, K: Ord + ToLua<'lua>, V: ToLua<'lua>> ToLua<'lua> for BTreeMap { - fn to_lua(self, lua: &'lua Lua) -> Result> { +impl IntoLua for BTreeMap { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { Ok(Value::Table(lua.create_table_from(self)?)) } } -impl<'lua, K: Ord + FromLua<'lua>, V: FromLua<'lua>> FromLua<'lua> for BTreeMap { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { - if let Value::Table(table) = value { - table.pairs().collect() - } else { - Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "BTreeMap", - message: Some("expected table".to_string()), - }) +impl FromLua for BTreeMap { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { + match value { + Value::Table(table) => table.pairs().collect(), + _ => Err(Error::from_lua_conversion( + value.type_name(), + Self::type_name(), + "expected table".to_string(), + )), } } } -impl<'lua, T: Eq + Hash + ToLua<'lua>, S: BuildHasher> ToLua<'lua> for HashSet { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::Table(lua.create_table_from( - self.into_iter().map(|val| (val, true)), - )?)) +impl IntoLua for HashSet { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + Ok(Value::Table( + lua.create_table_from(self.into_iter().map(|val| (val, true)))?, + )) } } -impl<'lua, T: Eq + Hash + FromLua<'lua>, S: BuildHasher + Default> FromLua<'lua> for HashSet { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { +impl FromLua for HashSet { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { match value { - Value::Table(table) if table.len()? > 0 => table.sequence_values().collect(), - Value::Table(table) => table - .pairs::>() - .map(|res| res.map(|(k, _)| k)) - .collect(), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "HashSet", - message: Some("expected table".to_string()), - }), + Value::Table(table) if table.raw_len() > 0 => table.sequence_values().collect(), + Value::Table(table) => table.pairs::().map(|res| res.map(|(k, _)| k)).collect(), + _ => Err(Error::from_lua_conversion( + value.type_name(), + Self::type_name(), + "expected table".to_string(), + )), } } } -impl<'lua, T: Ord + ToLua<'lua>> ToLua<'lua> for BTreeSet { - fn to_lua(self, lua: &'lua Lua) -> Result> { - Ok(Value::Table(lua.create_table_from( - self.into_iter().map(|val| (val, true)), - )?)) +impl IntoLua for BTreeSet { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + Ok(Value::Table( + lua.create_table_from(self.into_iter().map(|val| (val, true)))?, + )) } } -impl<'lua, T: Ord + FromLua<'lua>> FromLua<'lua> for BTreeSet { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { +impl FromLua for BTreeSet { + #[inline] + fn from_lua(value: Value, _: &Lua) -> Result { match value { - Value::Table(table) if table.len()? > 0 => table.sequence_values().collect(), - Value::Table(table) => table - .pairs::>() - .map(|res| res.map(|(k, _)| k)) - .collect(), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "BTreeSet", - message: Some("expected table".to_string()), - }), + Value::Table(table) if table.raw_len() > 0 => table.sequence_values().collect(), + Value::Table(table) => table.pairs::().map(|res| res.map(|(k, _)| k)).collect(), + _ => Err(Error::from_lua_conversion( + value.type_name(), + Self::type_name(), + "expected table".to_string(), + )), } } } -impl<'lua, T: ToLua<'lua>> ToLua<'lua> for Option { +impl IntoLua for Option { #[inline] - fn to_lua(self, lua: &'lua Lua) -> Result> { + fn into_lua(self, lua: &Lua) -> Result { match self { - Some(val) => val.to_lua(lua), + Some(val) => val.into_lua(lua), None => Ok(Nil), } } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + match self { + Some(val) => val.push_into_stack(lua)?, + None => ffi::lua_pushnil(lua.state()), + } + Ok(()) + } } -impl<'lua, T: FromLua<'lua>> FromLua<'lua> for Option { +impl FromLua for Option { #[inline] - fn from_lua(value: Value<'lua>, lua: &'lua Lua) -> Result { + fn from_lua(value: Value, lua: &Lua) -> Result { match value { Nil => Ok(None), value => Ok(Some(T::from_lua(value, lua)?)), } } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + match ffi::lua_type(lua.state(), idx) { + ffi::LUA_TNIL => Ok(None), + _ => Ok(Some(T::from_stack(idx, lua)?)), + } + } +} + +impl IntoLua for Either { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + match self { + Either::Left(l) => l.into_lua(lua), + Either::Right(r) => r.into_lua(lua), + } + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + match self { + Either::Left(l) => l.push_into_stack(lua), + Either::Right(r) => r.push_into_stack(lua), + } + } +} + +impl FromLua for Either { + #[inline] + fn from_lua(value: Value, lua: &Lua) -> Result { + let value_type_name = value.type_name(); + // Try the left type first + match L::from_lua(value.clone(), lua) { + Ok(l) => Ok(Either::Left(l)), + // Try the right type + Err(_) => match R::from_lua(value, lua).map(Either::Right) { + Ok(r) => Ok(r), + Err(_) => Err(Error::from_lua_conversion( + value_type_name, + Self::type_name(), + None, + )), + }, + } + } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + match L::from_stack(idx, lua) { + Ok(l) => Ok(Either::Left(l)), + Err(_) => match R::from_stack(idx, lua).map(Either::Right) { + Ok(r) => Ok(r), + Err(_) => { + let state = lua.state(); + let from_type_name = CStr::from_ptr(ffi::lua_typename(state, ffi::lua_type(state, idx))) + .to_str() + .unwrap_or("unknown"); + let err = Error::from_lua_conversion(from_type_name, Self::type_name(), None); + Err(err) + } + }, + } + } } diff --git a/src/debug.rs b/src/debug.rs new file mode 100644 index 00000000..89d8c501 --- /dev/null +++ b/src/debug.rs @@ -0,0 +1,394 @@ +//! Lua debugging interface. +//! +//! This module provides access to the Lua debug interface, allowing inspection of the call stack, +//! and function information. The main types are [`struct@Debug`] for accessing debug information +//! and [`HookTriggers`] for configuring debug hooks. + +use std::borrow::Cow; +use std::os::raw::c_int; + +use ffi::{lua_Debug, lua_State}; + +use crate::function::Function; +use crate::state::RawLua; +use crate::util::{StackGuard, assert_stack, linenumber_to_usize, ptr_to_lossy_str, ptr_to_str}; + +/// Contains information about currently executing Lua code. +/// +/// You may call the methods on this structure to retrieve information about the Lua code executing +/// at the specific level. Further information can be found in the Lua [documentation]. +/// +/// [documentation]: https://www.lua.org/manual/5.4/manual.html#lua_Debug +pub struct Debug<'a> { + state: *mut lua_State, + lua: &'a RawLua, + #[cfg_attr(not(feature = "luau"), allow(unused))] + level: c_int, + ar: *mut lua_Debug, +} + +impl<'a> Debug<'a> { + pub(crate) fn new(lua: &'a RawLua, level: c_int, ar: *mut lua_Debug) -> Self { + Debug { + state: lua.state(), + lua, + ar, + level, + } + } + + /// Returns the specific event that triggered the hook. + /// + /// For [Lua 5.1] [`DebugEvent::TailCall`] is used for return events to indicate a return + /// from a function that did a tail call. + /// + /// [Lua 5.1]: https://www.lua.org/manual/5.1/manual.html#pdf-LUA_HOOKTAILRET + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn event(&self) -> DebugEvent { + unsafe { + match (*self.ar).event { + ffi::LUA_HOOKCALL => DebugEvent::Call, + ffi::LUA_HOOKRET => DebugEvent::Ret, + ffi::LUA_HOOKTAILCALL => DebugEvent::TailCall, + ffi::LUA_HOOKLINE => DebugEvent::Line, + ffi::LUA_HOOKCOUNT => DebugEvent::Count, + event => DebugEvent::Unknown(event), + } + } + } + + /// Returns the function that is running at the given level. + /// + /// Corresponds to the `f` "what" mask. + pub fn function(&self) -> Function { + unsafe { + let _sg = StackGuard::new(self.state); + assert_stack(self.state, 1); + + #[cfg(not(feature = "luau"))] + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("f"), self.ar) != 0, + "lua_getinfo failed with `f`" + ); + #[cfg(feature = "luau")] + mlua_assert!( + ffi::lua_getinfo(self.state, self.level, cstr!("f"), self.ar) != 0, + "lua_getinfo failed with `f`" + ); + + ffi::lua_xmove(self.state, self.lua.ref_thread(), 1); + Function(self.lua.pop_ref_thread()) + } + } + + /// Corresponds to the `n` "what" mask. + pub fn names(&self) -> DebugNames<'_> { + unsafe { + #[cfg(not(feature = "luau"))] + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("n"), self.ar) != 0, + "lua_getinfo failed with `n`" + ); + #[cfg(feature = "luau")] + mlua_assert!( + ffi::lua_getinfo(self.state, self.level, cstr!("n"), self.ar) != 0, + "lua_getinfo failed with `n`" + ); + + DebugNames { + name: ptr_to_lossy_str((*self.ar).name), + #[cfg(not(feature = "luau"))] + name_what: match ptr_to_str((*self.ar).namewhat) { + Some("") => None, + val => val, + }, + #[cfg(feature = "luau")] + name_what: None, + } + } + } + + /// Corresponds to the `S` "what" mask. + pub fn source(&self) -> DebugSource<'_> { + unsafe { + #[cfg(not(feature = "luau"))] + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("S"), self.ar) != 0, + "lua_getinfo failed with `S`" + ); + #[cfg(feature = "luau")] + mlua_assert!( + ffi::lua_getinfo(self.state, self.level, cstr!("s"), self.ar) != 0, + "lua_getinfo failed with `s`" + ); + + DebugSource { + source: ptr_to_lossy_str((*self.ar).source), + #[cfg(not(feature = "luau"))] + short_src: ptr_to_lossy_str((*self.ar).short_src.as_ptr()), + #[cfg(feature = "luau")] + short_src: ptr_to_lossy_str((*self.ar).short_src), + line_defined: linenumber_to_usize((*self.ar).linedefined), + #[cfg(not(feature = "luau"))] + last_line_defined: linenumber_to_usize((*self.ar).lastlinedefined), + #[cfg(feature = "luau")] + last_line_defined: None, + what: ptr_to_str((*self.ar).what).unwrap_or("main"), + } + } + } + + /// Corresponds to the `l` "what" mask. Returns the current line. + pub fn current_line(&self) -> Option { + unsafe { + #[cfg(not(feature = "luau"))] + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("l"), self.ar) != 0, + "lua_getinfo failed with `l`" + ); + #[cfg(feature = "luau")] + mlua_assert!( + ffi::lua_getinfo(self.state, self.level, cstr!("l"), self.ar) != 0, + "lua_getinfo failed with `l`" + ); + + linenumber_to_usize((*self.ar).currentline) + } + } + + /// Corresponds to the `t` "what" mask. Returns true if the hook is in a function tail call, + /// false otherwise. + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))) + )] + pub fn is_tail_call(&self) -> bool { + unsafe { + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("t"), self.ar) != 0, + "lua_getinfo failed with `t`" + ); + (*self.ar).istailcall != 0 + } + } + + /// Corresponds to the `u` "what" mask. + pub fn stack(&self) -> DebugStack { + unsafe { + #[cfg(not(feature = "luau"))] + mlua_assert!( + ffi::lua_getinfo(self.state, cstr!("u"), self.ar) != 0, + "lua_getinfo failed with `u`" + ); + #[cfg(feature = "luau")] + mlua_assert!( + ffi::lua_getinfo(self.state, self.level, cstr!("au"), self.ar) != 0, + "lua_getinfo failed with `au`" + ); + + #[cfg(not(feature = "luau"))] + let stack = DebugStack { + num_upvalues: (*self.ar).nups as _, + #[cfg(not(any(feature = "lua51", feature = "luajit")))] + num_params: (*self.ar).nparams as _, + #[cfg(not(any(feature = "lua51", feature = "luajit")))] + is_vararg: (*self.ar).isvararg != 0, + }; + #[cfg(feature = "luau")] + let stack = DebugStack { + num_upvalues: (*self.ar).nupvals, + num_params: (*self.ar).nparams, + is_vararg: (*self.ar).isvararg != 0, + }; + stack + } + } +} + +/// Represents a specific event that triggered the hook. +#[cfg(not(feature = "luau"))] +#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DebugEvent { + Call, + Ret, + TailCall, + Line, + Count, + Unknown(c_int), +} + +/// Contains the name information of a function in the call stack. +/// +/// Returned by the [`Debug::names`] method. +#[derive(Clone, Debug)] +pub struct DebugNames<'a> { + /// A (reasonable) name of the function (`None` if the name cannot be found). + pub name: Option>, + /// Explains the `name` field (can be `global`/`local`/`method`/`field`/`upvalue`/etc). + /// + /// Always `None` for Luau. + pub name_what: Option<&'static str>, +} + +/// Contains the source information of a function in the call stack. +/// +/// Returned by the [`Debug::source`] method. +#[derive(Clone, Debug)] +pub struct DebugSource<'a> { + /// Source of the chunk that created the function. + pub source: Option>, + /// A "printable" version of `source`, to be used in error messages. + pub short_src: Option>, + /// The line number where the definition of the function starts. + pub line_defined: Option, + /// The line number where the definition of the function ends (not set by Luau). + pub last_line_defined: Option, + /// A string `Lua` if the function is a Lua function, `C` if it is a C function, `main` if it is + /// the main part of a chunk. + pub what: &'static str, +} + +/// Contains stack information about a function in the call stack. +/// +/// Returned by the [`Debug::stack`] method. +#[derive(Copy, Clone, Debug)] +pub struct DebugStack { + /// The number of upvalues of the function. + pub num_upvalues: u8, + /// The number of parameters of the function (always 0 for C). + #[cfg(any(not(any(feature = "lua51", feature = "luajit")), doc))] + #[cfg_attr(docsrs, doc(cfg(not(any(feature = "lua51", feature = "luajit")))))] + pub num_params: u8, + /// Whether the function is a variadic function (always true for C). + #[cfg(any(not(any(feature = "lua51", feature = "luajit")), doc))] + #[cfg_attr(docsrs, doc(cfg(not(any(feature = "lua51", feature = "luajit")))))] + pub is_vararg: bool, +} + +/// Determines when a hook function will be called by Lua. +#[cfg(not(feature = "luau"))] +#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] +#[derive(Clone, Copy, Debug, Default)] +pub struct HookTriggers { + /// Before a function call. + pub on_calls: bool, + /// When Lua returns from a function. + pub on_returns: bool, + /// Before executing a new line, or returning from a function call. + pub every_line: bool, + /// After a certain number of VM instructions have been executed. When set to `Some(count)`, + /// `count` is the number of VM instructions to execute before calling the hook. + /// + /// # Performance + /// + /// Setting this option to a low value can incur a very high overhead. + pub every_nth_instruction: Option, +} + +#[cfg(not(feature = "luau"))] +impl HookTriggers { + /// An instance of `HookTriggers` with `on_calls` trigger set. + pub const ON_CALLS: Self = HookTriggers::new().on_calls(); + + /// An instance of `HookTriggers` with `on_returns` trigger set. + pub const ON_RETURNS: Self = HookTriggers::new().on_returns(); + + /// An instance of `HookTriggers` with `every_line` trigger set. + pub const EVERY_LINE: Self = HookTriggers::new().every_line(); + + /// Returns a new instance of `HookTriggers` with all triggers disabled. + pub const fn new() -> Self { + HookTriggers { + on_calls: false, + on_returns: false, + every_line: false, + every_nth_instruction: None, + } + } + + /// Returns an instance of `HookTriggers` with [`on_calls`] trigger set. + /// + /// [`on_calls`]: #structfield.on_calls + pub const fn on_calls(mut self) -> Self { + self.on_calls = true; + self + } + + /// Returns an instance of `HookTriggers` with [`on_returns`] trigger set. + /// + /// [`on_returns`]: #structfield.on_returns + pub const fn on_returns(mut self) -> Self { + self.on_returns = true; + self + } + + /// Returns an instance of `HookTriggers` with [`every_line`] trigger set. + /// + /// [`every_line`]: #structfield.every_line + pub const fn every_line(mut self) -> Self { + self.every_line = true; + self + } + + /// Returns an instance of `HookTriggers` with [`every_nth_instruction`] trigger set. + /// + /// [`every_nth_instruction`]: #structfield.every_nth_instruction + pub const fn every_nth_instruction(mut self, n: u32) -> Self { + self.every_nth_instruction = Some(n); + self + } + + // Compute the mask to pass to `lua_sethook`. + #[cfg(not(feature = "luau"))] + pub(crate) const fn mask(&self) -> c_int { + let mut mask: c_int = 0; + if self.on_calls { + mask |= ffi::LUA_MASKCALL + } + if self.on_returns { + mask |= ffi::LUA_MASKRET + } + if self.every_line { + mask |= ffi::LUA_MASKLINE + } + if self.every_nth_instruction.is_some() { + mask |= ffi::LUA_MASKCOUNT + } + mask + } + + // Returns the `count` parameter to pass to `lua_sethook`, if applicable. Otherwise, zero is + // returned. + #[cfg(not(feature = "luau"))] + pub(crate) const fn count(&self) -> c_int { + match self.every_nth_instruction { + Some(n) => n as c_int, + None => 0, + } + } +} + +#[cfg(not(feature = "luau"))] +impl std::ops::BitOr for HookTriggers { + type Output = Self; + + fn bitor(mut self, rhs: Self) -> Self::Output { + self.on_calls |= rhs.on_calls; + self.on_returns |= rhs.on_returns; + self.every_line |= rhs.every_line; + if self.every_nth_instruction.is_none() && rhs.every_nth_instruction.is_some() { + self.every_nth_instruction = rhs.every_nth_instruction; + } + self + } +} + +#[cfg(not(feature = "luau"))] +impl std::ops::BitOrAssign for HookTriggers { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } +} diff --git a/src/error.rs b/src/error.rs index e50a2b2b..0664618a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,7 @@ -#![allow(clippy::wrong_self_convention)] +//! Lua error handling. +//! +//! This module provides the [`Error`] type returned by all fallible `mlua` operations, together +//! with extension traits for adapting Rust errors for use within Lua. use std::error::Error as StdError; use std::fmt; @@ -6,9 +9,16 @@ use std::io::Error as IoError; use std::net::AddrParseError; use std::result::Result as StdResult; use std::str::Utf8Error; -use std::string::String as StdString; use std::sync::Arc; +use crate::private::Sealed; + +#[cfg(feature = "error-send")] +type DynStdError = dyn StdError + Send + Sync; + +#[cfg(not(feature = "error-send"))] +type DynStdError = dyn StdError; + /// Error type returned by `mlua` methods. #[derive(Debug, Clone)] #[non_exhaustive] @@ -16,7 +26,7 @@ pub enum Error { /// Syntax error while parsing Lua source code. SyntaxError { /// The error message as returned by Lua. - message: StdString, + message: String, /// `true` if the error can likely be fixed by appending more input to the source code. /// /// This is useful for implementing REPLs as they can query the user for more input if this @@ -28,30 +38,25 @@ pub enum Error { /// The Lua VM returns this error when a builtin operation is performed on incompatible types. /// Among other things, this includes invoking operators on wrong types (such as calling or /// indexing a `nil` value). - RuntimeError(StdString), + RuntimeError(String), /// Lua memory error, aka `LUA_ERRMEM` /// /// The Lua VM returns this error when the allocator does not return the requested memory, aka /// it is an out-of-memory error. - MemoryError(StdString), + MemoryError(String), /// Lua garbage collector error, aka `LUA_ERRGCMM`. /// /// The Lua VM returns this error when there is an error running a `__gc` metamethod. #[cfg(any(feature = "lua53", feature = "lua52", doc))] #[cfg_attr(docsrs, doc(cfg(any(feature = "lua53", feature = "lua52"))))] - GarbageCollectorError(StdString), + GarbageCollectorError(String), /// Potentially unsafe action in safe mode. - SafetyError(StdString), - /// Setting memory limit is not available. + SafetyError(String), + /// Memory control is not available. /// /// This error can only happen when Lua state was not created by us and does not have the /// custom allocator attached. - MemoryLimitNotAvailable, - /// Main thread is not available. - /// - /// This error can only happen in Lua5.1/LuaJIT module mode, when module loaded within a coroutine. - /// These Lua versions does not have `LUA_RIDX_MAINTHREAD` registry key. - MainThreadNotAvailable, + MemoryControlNotAvailable, /// A mutable callback has triggered Lua code that has called the same mutable callback again. /// /// This is an error because a mutable callback can only be borrowed mutably once. @@ -66,40 +71,47 @@ pub enum Error { /// /// Due to the way `mlua` works, it should not be directly possible to run out of stack space /// during normal use. The only way that this error can be triggered is if a `Function` is - /// called with a huge number of arguments, or a rust callback returns a huge number of return + /// called with a huge number of arguments, or a Rust callback returns a huge number of return /// values. StackError, - /// Too many arguments to `Function::bind` + /// Too many arguments to [`Function::bind`]. + /// + /// [`Function::bind`]: crate::Function::bind BindError, - /// A Rust value could not be converted to a Lua value. - ToLuaConversionError { - /// Name of the Rust type that could not be converted. - from: &'static str, - /// Name of the Lua type that could not be created. - to: &'static str, - /// A message indicating why the conversion failed in more detail. - message: Option, + /// Bad argument received from Lua (usually when calling a function). + /// + /// This error can help to identify the argument that caused the error + /// (which is stored in the corresponding field). + BadArgument { + /// Function that was called. + to: Option, + /// Argument position (usually starts from 1). + pos: usize, + /// Argument name. + name: Option, + /// Underlying error returned when converting argument to a Lua value. + cause: Arc, }, /// A Lua value could not be converted to the expected Rust type. FromLuaConversionError { /// Name of the Lua type that could not be converted. from: &'static str, /// Name of the Rust type that could not be created. - to: &'static str, + to: String, /// A string containing more detailed error information. - message: Option, + message: Option, }, - /// [`Thread::resume`] was called on an inactive coroutine. + /// [`Thread::resume`] was called on an unresumable coroutine. /// - /// A coroutine is inactive if its main function has returned or if an error has occurred inside - /// the coroutine. + /// A coroutine is unresumable if its main function has returned or if an error has occurred + /// inside the coroutine. Already running coroutines are also marked as unresumable. /// /// [`Thread::status`] can be used to check if the coroutine can be resumed without causing this /// error. /// /// [`Thread::resume`]: crate::Thread::resume /// [`Thread::status`]: crate::Thread::status - CoroutineInactive, + CoroutineUnresumable, /// An [`AnyUserData`] is not the expected type in a borrow. /// /// This error can only happen when manually using [`AnyUserData`], or when implementing @@ -116,7 +128,7 @@ pub enum Error { /// /// [`AnyUserData`]: crate::AnyUserData UserDataDestructed, - /// An [`AnyUserData`] immutable borrow failed because it is already borrowed mutably. + /// An [`AnyUserData`] immutable borrow failed. /// /// This error can occur when a method on a [`UserData`] type calls back into Lua, which then /// tries to call a method on the same [`UserData`] type. Consider restructuring your API to @@ -125,7 +137,7 @@ pub enum Error { /// [`AnyUserData`]: crate::AnyUserData /// [`UserData`]: crate::UserData UserDataBorrowError, - /// An [`AnyUserData`] mutable borrow failed because it is already borrowed. + /// An [`AnyUserData`] mutable borrow failed. /// /// This error can occur when a method on a [`UserData`] type calls back into Lua, which then /// tries to call a method on the same [`UserData`] type. Consider restructuring your API to @@ -137,14 +149,17 @@ pub enum Error { /// A [`MetaMethod`] operation is restricted (typically for `__gc` or `__metatable`). /// /// [`MetaMethod`]: crate::MetaMethod - MetaMethodRestricted(StdString), + MetaMethodRestricted(String), /// A [`MetaMethod`] (eg. `__index` or `__newindex`) has invalid type. /// /// [`MetaMethod`]: crate::MetaMethod MetaMethodTypeError { - method: StdString, + /// Name of the metamethod. + method: String, + /// Passed value type. type_name: &'static str, - message: Option, + /// A string containing more detailed error information. + message: Option, }, /// A [`RegistryKey`] produced from a different Lua state was used. /// @@ -153,7 +168,7 @@ pub enum Error { /// A Rust callback returned `Err`, raising the contained `Error` as a Lua error. CallbackError { /// Lua call stack backtrace. - traceback: StdString, + traceback: String, /// Original error returned by the Rust code. cause: Arc, }, @@ -163,13 +178,13 @@ pub enum Error { /// and returned again. PreviouslyResumedPanic, /// Serialization error. - #[cfg(feature = "serialize")] - #[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] - SerializeError(StdString), + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + SerializeError(String), /// Deserialization error. - #[cfg(feature = "serialize")] - #[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] - DeserializeError(StdString), + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + DeserializeError(String), /// A custom error. /// /// This can be used for returning user-defined errors from callbacks. @@ -177,7 +192,14 @@ pub enum Error { /// Returning `Err(ExternalError(...))` from a Rust callback will raise the error as a Lua /// error. The Rust code that originally invoked the Lua code then receives a `CallbackError`, /// from which the original error (and a stack traceback) can be recovered. - ExternalError(Arc), + ExternalError(Arc), + /// An error with additional context. + WithContext { + /// A string containing additional context. + context: String, + /// Underlying error. + cause: Arc, + }, } /// A specialized `Result` type used by `mlua`'s API. @@ -186,24 +208,21 @@ pub type Result = StdResult; #[cfg(not(tarpaulin_include))] impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::SyntaxError { ref message, .. } => write!(fmt, "syntax error: {}", message), - Error::RuntimeError(ref msg) => write!(fmt, "runtime error: {}", msg), - Error::MemoryError(ref msg) => { - write!(fmt, "memory error: {}", msg) + match self { + Error::SyntaxError { message, .. } => write!(fmt, "syntax error: {message}"), + Error::RuntimeError(msg) => write!(fmt, "runtime error: {msg}"), + Error::MemoryError(msg) => { + write!(fmt, "memory error: {msg}") } #[cfg(any(feature = "lua53", feature = "lua52"))] - Error::GarbageCollectorError(ref msg) => { - write!(fmt, "garbage collector error: {}", msg) + Error::GarbageCollectorError(msg) => { + write!(fmt, "garbage collector error: {msg}") } - Error::SafetyError(ref msg) => { - write!(fmt, "safety error: {}", msg) - }, - Error::MemoryLimitNotAvailable => { - write!(fmt, "setting memory limit is not available") + Error::SafetyError(msg) => { + write!(fmt, "safety error: {msg}") } - Error::MainThreadNotAvailable => { - write!(fmt, "main thread is not available in Lua 5.1") + Error::MemoryControlNotAvailable => { + write!(fmt, "memory control is not available") } Error::RecursiveMutCallback => write!(fmt, "mutable callback called recursively"), Error::CallbackDestructed => write!( @@ -214,48 +233,57 @@ impl fmt::Display for Error { fmt, "out of Lua stack, too many arguments to a Lua function or too many return values from a callback" ), - Error::BindError => write!( - fmt, - "too many arguments to Function::bind" - ), - Error::ToLuaConversionError { from, to, ref message } => { - write!(fmt, "error converting {} to Lua {}", from, to)?; - match *message { - None => Ok(()), - Some(ref message) => write!(fmt, " ({})", message), + Error::BindError => write!(fmt, "too many arguments to Function::bind"), + Error::BadArgument { to, pos, name, cause } => { + if let Some(name) = name { + write!(fmt, "bad argument `{name}`")?; + } else { + write!(fmt, "bad argument #{pos}")?; + } + if let Some(to) = to { + write!(fmt, " to `{to}`")?; } + write!(fmt, ": {cause}") } - Error::FromLuaConversionError { from, to, ref message } => { - write!(fmt, "error converting Lua {} to {}", from, to)?; - match *message { + Error::FromLuaConversionError { from, to, message } => { + write!(fmt, "error converting Lua {from} to {to}")?; + match message { None => Ok(()), - Some(ref message) => write!(fmt, " ({})", message), + Some(message) => write!(fmt, " ({message})"), } } - Error::CoroutineInactive => write!(fmt, "cannot resume inactive coroutine"), + Error::CoroutineUnresumable => write!(fmt, "coroutine is non-resumable"), Error::UserDataTypeMismatch => write!(fmt, "userdata is not expected type"), Error::UserDataDestructed => write!(fmt, "userdata has been destructed"), - Error::UserDataBorrowError => write!(fmt, "userdata already mutably borrowed"), - Error::UserDataBorrowMutError => write!(fmt, "userdata already borrowed"), - Error::MetaMethodRestricted(ref method) => write!(fmt, "metamethod {} is restricted", method), - Error::MetaMethodTypeError { ref method, type_name, ref message } => { - write!(fmt, "metamethod {} has unsupported type {}", method, type_name)?; - match *message { + Error::UserDataBorrowError => write!(fmt, "error borrowing userdata"), + Error::UserDataBorrowMutError => write!(fmt, "error mutably borrowing userdata"), + Error::MetaMethodRestricted(method) => write!(fmt, "metamethod {method} is restricted"), + Error::MetaMethodTypeError { + method, + type_name, + message, + } => { + write!(fmt, "metamethod {method} has unsupported type {type_name}")?; + match message { None => Ok(()), - Some(ref message) => write!(fmt, " ({})", message), + Some(message) => write!(fmt, " ({message})"), } } Error::MismatchedRegistryKey => { write!(fmt, "RegistryKey used from different Lua state") } - Error::CallbackError { ref cause, ref traceback } => { - writeln!(fmt, "callback error")?; + Error::CallbackError { cause, traceback } => { // Trace errors down to the root let (mut cause, mut full_traceback) = (cause, None); - while let Error::CallbackError { cause: ref cause2, traceback: ref traceback2 } = **cause { + while let Error::CallbackError { + cause: cause2, + traceback: traceback2, + } = &**cause + { cause = cause2; full_traceback = Some(traceback2); } + writeln!(fmt, "{cause}")?; if let Some(full_traceback) = full_traceback { let traceback = traceback.trim_start_matches("stack traceback:"); let traceback = traceback.trim_start().trim_end(); @@ -269,95 +297,270 @@ impl fmt::Display for Error { } else { writeln!(fmt, "{}", traceback.trim_end())?; } - write!(fmt, "caused by: {}", cause) + Ok(()) } Error::PreviouslyResumedPanic => { write!(fmt, "previously resumed panic returned again") } - #[cfg(feature = "serialize")] - Error::SerializeError(ref err) => { - write!(fmt, "serialize error: {}", err) - }, - #[cfg(feature = "serialize")] - Error::DeserializeError(ref err) => { - write!(fmt, "deserialize error: {}", err) - }, - Error::ExternalError(ref err) => write!(fmt, "{}", err), + #[cfg(feature = "serde")] + Error::SerializeError(err) => { + write!(fmt, "serialize error: {err}") + } + #[cfg(feature = "serde")] + Error::DeserializeError(err) => { + write!(fmt, "deserialize error: {err}") + } + Error::ExternalError(err) => err.fmt(fmt), + Error::WithContext { context, cause } => { + writeln!(fmt, "{context}")?; + write!(fmt, "{cause}") + } } } } impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { + match self { // An error type with a source error should either return that error via source or // include that source's error message in its own Display output, but never both. // https://blog.rust-lang.org/inside-rust/2021/07/01/What-the-error-handling-project-group-is-working-towards.html - // Given that we include source to fmt::Display implementation for `CallbackError`, this call returns nothing. + // Given that we include source to fmt::Display implementation for `CallbackError`, this call + // returns nothing. Error::CallbackError { .. } => None, - Error::ExternalError(ref err) => err.source(), + Error::ExternalError(err) => err.source(), + Error::WithContext { cause, .. } => Self::source(cause), _ => None, } } } impl Error { - pub fn external>>(err: T) -> Error { - Error::ExternalError(err.into().into()) + /// Creates a new `RuntimeError` with the given message. + #[inline] + pub fn runtime(message: S) -> Self { + Error::RuntimeError(message.to_string()) + } + + /// Wraps an external error object. + #[inline] + pub fn external>>(err: T) -> Self { + let boxed = err.into(); + match boxed.downcast::() { + Ok(err) => *err, + Err(boxed) => Error::ExternalError(boxed.into()), + } + } + + /// Attempts to downcast the external error object to a concrete type by reference. + pub fn downcast_ref(&self) -> Option<&T> + where + T: StdError + 'static, + { + match self { + Error::ExternalError(err) => err.downcast_ref(), + Error::WithContext { cause, .. } => Self::downcast_ref(cause), + _ => None, + } + } + + /// An iterator over the chain of nested errors wrapped by this Error. + pub fn chain(&self) -> impl Iterator { + Chain { + root: self, + current: None, + } + } + + /// Returns the parent of this error. + #[doc(hidden)] + pub fn parent(&self) -> Option<&Error> { + match self { + Error::CallbackError { cause, .. } => Some(cause.as_ref()), + Error::WithContext { cause, .. } => Some(cause.as_ref()), + _ => None, + } + } + + pub(crate) fn bad_self_argument(to: &str, cause: Error) -> Self { + Error::BadArgument { + to: Some(to.to_string()), + pos: 1, + name: Some("self".to_string()), + cause: Arc::new(cause), + } + } + + #[inline] + pub(crate) fn from_lua_conversion( + from: &'static str, + to: impl ToString, + message: impl Into>, + ) -> Self { + Error::FromLuaConversionError { + from, + to: to.to_string(), + message: message.into(), + } } } +/// Trait for converting [`std::error::Error`] into Lua [`Error`]. pub trait ExternalError { - fn to_lua_err(self) -> Error; + fn into_lua_err(self) -> Error; } -impl>> ExternalError for E { - fn to_lua_err(self) -> Error { +impl>> ExternalError for E { + fn into_lua_err(self) -> Error { Error::external(self) } } +/// Trait for converting [`std::result::Result`] into Lua [`Result`]. pub trait ExternalResult { - fn to_lua_err(self) -> Result; + fn into_lua_err(self) -> Result; } impl ExternalResult for StdResult where E: ExternalError, { - fn to_lua_err(self) -> Result { - self.map_err(|e| e.to_lua_err()) + fn into_lua_err(self) -> Result { + self.map_err(|e| e.into_lua_err()) + } +} + +/// Provides the `context` method for [`Error`] and `Result`. +pub trait ErrorContext: Sealed { + /// Wraps the error value with additional context. + fn context(self, context: C) -> Self; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: impl FnOnce(&Error) -> C) -> Self; +} + +impl ErrorContext for Error { + fn context(self, context: C) -> Self { + let context = context.to_string(); + match self { + Error::WithContext { cause, .. } => Error::WithContext { context, cause }, + _ => Error::WithContext { + context, + cause: Arc::new(self), + }, + } + } + + fn with_context(self, f: impl FnOnce(&Error) -> C) -> Self { + let context = f(&self).to_string(); + match self { + Error::WithContext { cause, .. } => Error::WithContext { context, cause }, + _ => Error::WithContext { + context, + cause: Arc::new(self), + }, + } + } +} + +impl ErrorContext for Result { + fn context(self, context: C) -> Self { + self.map_err(|err| err.context(context)) + } + + fn with_context(self, f: impl FnOnce(&Error) -> C) -> Self { + self.map_err(|err| err.with_context(f)) } } -impl std::convert::From for Error { +impl From for Error { fn from(err: AddrParseError) -> Self { Error::external(err) } } -impl std::convert::From for Error { +impl From for Error { fn from(err: IoError) -> Self { Error::external(err) } } -impl std::convert::From for Error { +impl From for Error { fn from(err: Utf8Error) -> Self { Error::external(err) } } -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] impl serde::ser::Error for Error { fn custom(msg: T) -> Self { Self::SerializeError(msg.to_string()) } } -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] impl serde::de::Error for Error { fn custom(msg: T) -> Self { Self::DeserializeError(msg.to_string()) } } + +#[cfg(feature = "anyhow")] +impl From for Error { + fn from(err: anyhow::Error) -> Self { + match err.downcast::() { + Ok(err) => err, + Err(err) => Error::external(err), + } + } +} + +struct Chain<'a> { + root: &'a Error, + current: Option<&'a (dyn StdError + 'static)>, +} + +impl<'a> Iterator for Chain<'a> { + type Item = &'a (dyn StdError + 'static); + + fn next(&mut self) -> Option { + loop { + let error: Option<&dyn StdError> = match self.current { + None => { + self.current = Some(self.root); + self.current + } + Some(current) => match current.downcast_ref::()? { + Error::BadArgument { cause, .. } + | Error::CallbackError { cause, .. } + | Error::WithContext { cause, .. } => { + self.current = Some(&**cause); + self.current + } + Error::ExternalError(err) => { + self.current = Some(&**err); + self.current + } + _ => None, + }, + }; + + // Skip `ExternalError` as it only wraps the underlying error + // without meaningful context + if let Some(Error::ExternalError(_)) = error?.downcast_ref::() { + continue; + } + + return self.current; + } + } +} + +#[cfg(test)] +mod assertions { + #[cfg(not(feature = "error-send"))] + static_assertions::assert_not_impl_any!(super::Error: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(super::Error: Send, Sync); +} diff --git a/src/ffi/lua52/lualib.rs b/src/ffi/lua52/lualib.rs deleted file mode 100644 index 9ffba45b..00000000 --- a/src/ffi/lua52/lualib.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! Contains definitions from `lualib.h`. - -use std::os::raw::c_int; - -use super::lua::lua_State; - -pub const LUA_COLIBNAME: &str = "coroutine"; -pub const LUA_TABLIBNAME: &str = "table"; -pub const LUA_IOLIBNAME: &str = "io"; -pub const LUA_OSLIBNAME: &str = "os"; -pub const LUA_STRLIBNAME: &str = "string"; -pub const LUA_BITLIBNAME: &str = "bit32"; -pub const LUA_MATHLIBNAME: &str = "math"; -pub const LUA_DBLIBNAME: &str = "debug"; -pub const LUA_LOADLIBNAME: &str = "package"; - -extern "C" { - pub fn luaopen_base(L: *mut lua_State) -> c_int; - pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; - pub fn luaopen_table(L: *mut lua_State) -> c_int; - pub fn luaopen_io(L: *mut lua_State) -> c_int; - pub fn luaopen_os(L: *mut lua_State) -> c_int; - pub fn luaopen_string(L: *mut lua_State) -> c_int; - pub fn luaopen_bit32(L: *mut lua_State) -> c_int; - pub fn luaopen_math(L: *mut lua_State) -> c_int; - pub fn luaopen_debug(L: *mut lua_State) -> c_int; - pub fn luaopen_package(L: *mut lua_State) -> c_int; - - // open all builtin libraries - pub fn luaL_openlibs(L: *mut lua_State); -} diff --git a/src/ffi/lua53/compat.rs b/src/ffi/lua53/compat.rs deleted file mode 100644 index 7150b713..00000000 --- a/src/ffi/lua53/compat.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! MLua compatibility layer for Lua 5.2 - -use std::os::raw::c_int; - -use super::lua::*; - -#[inline(always)] -pub unsafe fn lua_resume( - L: *mut lua_State, - from: *mut lua_State, - narg: c_int, - nres: *mut c_int, -) -> c_int { - let ret = lua_resume_(L, from, narg); - if (ret == LUA_OK || ret == LUA_YIELD) && !(nres.is_null()) { - *nres = lua_gettop(L); - } - ret -} diff --git a/src/ffi/luau/luacode.rs b/src/ffi/luau/luacode.rs deleted file mode 100644 index 571db4c2..00000000 --- a/src/ffi/luau/luacode.rs +++ /dev/null @@ -1,39 +0,0 @@ -//! Contains definitions from `luacode.h`. - -use std::os::raw::{c_char, c_int, c_void}; -use std::slice; - -#[repr(C)] -pub struct lua_CompileOptions { - pub optimizationLevel: c_int, - pub debugLevel: c_int, - pub coverageLevel: c_int, - pub vectorLib: *const c_char, - pub vectorCtor: *const c_char, - pub mutableGlobals: *mut *const c_char, -} - -extern "C" { - #[link_name = "luau_compile"] - pub fn luau_compile_( - source: *const c_char, - size: usize, - options: *mut lua_CompileOptions, - outsize: *mut usize, - ) -> *mut c_char; - - fn free(p: *mut c_void); -} - -pub unsafe fn luau_compile(source: &[u8], mut options: lua_CompileOptions) -> Vec { - let mut outsize = 0; - let data_ptr = luau_compile_( - source.as_ptr() as *const c_char, - source.len(), - &mut options, - &mut outsize, - ); - let data = slice::from_raw_parts(data_ptr as *mut u8, outsize).to_vec(); - free(data_ptr as *mut c_void); - data -} diff --git a/src/ffi/luau/lualib.rs b/src/ffi/luau/lualib.rs deleted file mode 100644 index 92b8642c..00000000 --- a/src/ffi/luau/lualib.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Contains definitions from `lualib.h`. - -use std::os::raw::c_int; - -use super::lua::lua_State; - -pub const LUA_COLIBNAME: &str = "coroutine"; -pub const LUA_TABLIBNAME: &str = "table"; -pub const LUA_OSLIBNAME: &str = "os"; -pub const LUA_STRLIBNAME: &str = "string"; -pub const LUA_BITLIBNAME: &str = "bit32"; -pub const LUA_UTF8LIBNAME: &str = "utf8"; -pub const LUA_MATHLIBNAME: &str = "math"; -pub const LUA_DBLIBNAME: &str = "debug"; - -extern "C" { - pub fn luaopen_base(L: *mut lua_State) -> c_int; - pub fn luaopen_coroutine(L: *mut lua_State) -> c_int; - pub fn luaopen_table(L: *mut lua_State) -> c_int; - pub fn luaopen_os(L: *mut lua_State) -> c_int; - pub fn luaopen_string(L: *mut lua_State) -> c_int; - pub fn luaopen_bit32(L: *mut lua_State) -> c_int; - pub fn luaopen_utf8(L: *mut lua_State) -> c_int; - pub fn luaopen_math(L: *mut lua_State) -> c_int; - pub fn luaopen_debug(L: *mut lua_State) -> c_int; - - // open all builtin libraries - pub fn luaL_openlibs(L: *mut lua_State); -} diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs deleted file mode 100644 index 09164b9b..00000000 --- a/src/ffi/mod.rs +++ /dev/null @@ -1,101 +0,0 @@ -//! Low level bindings to Lua 5.4/5.3/5.2/5.1 including LuaJIT. - -#![allow(non_camel_case_types, non_snake_case, dead_code)] - -use std::os::raw::c_int; - -#[cfg(feature = "lua54")] -pub use lua54::*; - -#[cfg(feature = "lua53")] -pub use lua53::*; - -#[cfg(feature = "lua52")] -pub use lua52::*; - -#[cfg(any(feature = "lua51", feature = "luajit"))] -pub use lua51::*; - -#[cfg(feature = "luau")] -pub use luau::*; - -#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] -pub const LUA_MAX_UPVALUES: c_int = 255; - -#[cfg(any(feature = "lua51", all(feature = "luajit", not(feature = "vendored"))))] -pub const LUA_MAX_UPVALUES: c_int = 60; - -#[cfg(all(feature = "luajit", feature = "vendored"))] -pub const LUA_MAX_UPVALUES: c_int = 120; - -#[cfg(feature = "luau")] -pub const LUA_MAX_UPVALUES: c_int = 200; - -// I believe `luaL_traceback` < 5.4 requires this much free stack to not error. -// 5.4 uses `luaL_Buffer` -pub const LUA_TRACEBACK_STACK: c_int = 11; - -// The minimum alignment guaranteed by the architecture. This value is used to -// add fast paths for low alignment values. -// Copied from https://github.com/rust-lang/rust/blob/master/library/std/src/sys/common/alloc.rs -#[cfg(all(any( - target_arch = "x86", - target_arch = "arm", - target_arch = "mips", - target_arch = "powerpc", - target_arch = "powerpc64", - target_arch = "sparc", - target_arch = "asmjs", - target_arch = "wasm32", - target_arch = "hexagon", - all(target_arch = "riscv32", not(target_os = "espidf")), - all(target_arch = "xtensa", not(target_os = "espidf")), -)))] -pub const SYS_MIN_ALIGN: usize = 8; -#[cfg(all(any( - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "mips64", - target_arch = "s390x", - target_arch = "sparc64", - target_arch = "riscv64", - target_arch = "wasm64", -)))] -pub const SYS_MIN_ALIGN: usize = 16; -// The allocator on the esp-idf platform guarentees 4 byte alignment. -#[cfg(all(any( - all(target_arch = "riscv32", target_os = "espidf"), - all(target_arch = "xtensa", target_os = "espidf"), -)))] -pub const SYS_MIN_ALIGN: usize = 4; - -// Hack to avoid stripping a few unused Lua symbols that could be imported -// by C modules in unsafe mode -#[cfg(not(feature = "luau"))] -pub(crate) fn keep_lua_symbols() { - let mut symbols: Vec<*const extern "C" fn()> = Vec::new(); - symbols.push(lua_atpanic as _); - symbols.push(lua_isuserdata as _); - symbols.push(lua_tocfunction as _); - symbols.push(luaL_loadstring as _); - symbols.push(luaL_openlibs as _); - if cfg!(any(feature = "lua54", feature = "lua53", feature = "lua52")) { - symbols.push(lua_getglobal as _); - symbols.push(lua_setglobal as _); - } -} - -#[cfg(feature = "lua54")] -pub mod lua54; - -#[cfg(feature = "lua53")] -pub mod lua53; - -#[cfg(feature = "lua52")] -pub mod lua52; - -#[cfg(any(feature = "lua51", feature = "luajit"))] -pub mod lua51; - -#[cfg(feature = "luau")] -pub mod luau; diff --git a/src/function.rs b/src/function.rs index 541627f6..234bbe23 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,35 +1,155 @@ -use std::mem; -use std::os::raw::c_int; -use std::ptr; +//! Lua function handling. +//! +//! This module provides types for working with Lua functions from Rust, including +//! both Lua-defined functions and native Rust callbacks. +//! +//! # Calling Functions +//! +//! Use [`Function::call`] to invoke a Lua function synchronously: +//! +//! ``` +//! # use mlua::{Function, Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! +//! // Get a built-in function +//! let print: Function = lua.globals().get("print")?; +//! print.call::<()>("Hello from Rust!")?; +//! +//! // Call a function that returns values +//! let tonumber: Function = lua.globals().get("tonumber")?; +//! let n: i32 = tonumber.call("42")?; +//! assert_eq!(n, 42); +//! # Ok(()) +//! # } +//! ``` +//! +//! For asynchronous execution, use `Function::call_async` (requires `async` feature): +//! +//! ```ignore +//! let result: String = my_async_func.call_async(args).await?; +//! ``` +//! +//! # Creating Functions +//! +//! Functions can be created from Rust closures using [`Lua::create_function`]: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! +//! let greet = lua.create_function(|_, name: String| { +//! Ok(format!("Hello, {}!", name)) +//! })?; +//! +//! lua.globals().set("greet", greet)?; +//! let result: String = lua.load(r#"greet("World")"#).eval()?; +//! assert_eq!(result, "Hello, World!"); +//! # Ok(()) +//! # } +//! ``` +//! +//! For simpler cases, use [`Function::wrap`] or [`Function::wrap_raw`] to convert a Rust function +//! directly: +//! +//! ``` +//! # use mlua::{Function, Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! +//! fn add(a: i32, b: i32) -> i32 { a + b } +//! +//! lua.globals().set("add", Function::wrap_raw(add))?; +//! let sum: i32 = lua.load("add(2, 3)").eval()?; +//! assert_eq!(sum, 5); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Function Environments +//! +//! Lua functions have an associated environment table that determines how global +//! variables are resolved. Use [`Function::environment`] and [`Function::set_environment`] +//! to inspect or modify this environment. -use crate::error::{Error, Result}; -use crate::ffi; -use crate::types::LuaRef; +use std::cell::RefCell; +use std::os::raw::{c_int, c_void}; +use std::result::Result as StdResult; +use std::{mem, ptr, slice}; + +use crate::error::{Error, ExternalError, ExternalResult, Result}; +use crate::state::Lua; +use crate::table::Table; +use crate::traits::{FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::types::{Callback, LuaType, MaybeSend, ValueRef}; use crate::util::{ - assert_stack, check_stack, error_traceback, pop_error, ptr_to_cstr_bytes, StackGuard, + StackGuard, assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, }; -use crate::value::{FromLuaMulti, ToLuaMulti}; +use crate::value::Value; #[cfg(feature = "async")] -use {futures_core::future::LocalBoxFuture, futures_util::future}; +use { + crate::thread::AsyncThread, + crate::types::AsyncCallback, + std::future::{self, Future}, + std::pin::{Pin, pin}, + std::task::{Context, Poll}, +}; /// Handle to an internal Lua function. -#[derive(Clone, Debug)] -pub struct Function<'lua>(pub(crate) LuaRef<'lua>); +#[derive(Clone, Debug, PartialEq)] +pub struct Function(pub(crate) ValueRef); +/// Contains information about a function. +/// +/// Please refer to the [`Lua Debug Interface`] for more information. +/// +/// [`Lua Debug Interface`]: https://www.lua.org/manual/5.4/manual.html#4.7 #[derive(Clone, Debug)] +#[non_exhaustive] pub struct FunctionInfo { - pub name: Option>, - pub name_what: Option>, - pub what: Option>, - pub source: Option>, - pub short_src: Option>, + /// A (reasonable) name of the function (`None` if the name cannot be found). + pub name: Option, + /// Explains the `name` field (can be `global`/`local`/`method`/`field`/`upvalue`/etc). + /// + /// Always `None` for Luau. + pub name_what: Option<&'static str>, + /// A string `Lua` if the function is a Lua function, `C` if it is a C function, `main` if it is + /// the main part of a chunk. + pub what: &'static str, + /// Source of the chunk that created the function. + pub source: Option, + /// A "printable" version of `source`, to be used in error messages. + pub short_src: Option, + /// The line number where the definition of the function starts. + pub line_defined: Option, + /// The line number where the definition of the function ends (not set by Luau). + pub last_line_defined: Option, + /// The number of upvalues of the function. + pub num_upvalues: u8, + /// The number of parameters of the function (always 0 for C). + #[cfg(any(not(any(feature = "lua51", feature = "luajit")), doc))] + #[cfg_attr(docsrs, doc(cfg(not(any(feature = "lua51", feature = "luajit")))))] + pub num_params: u8, + /// Whether the function is a variadic function (always true for C). + #[cfg(any(not(any(feature = "lua51", feature = "luajit")), doc))] + #[cfg_attr(docsrs, doc(cfg(not(any(feature = "lua51", feature = "luajit")))))] + pub is_vararg: bool, +} + +/// Luau function coverage snapshot. +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CoverageInfo { + pub function: Option, pub line_defined: i32, - #[cfg(not(feature = "luau"))] - pub last_line_defined: i32, + pub depth: i32, + pub hits: Vec, } -impl<'lua> Function<'lua> { +impl Function { /// Calls the function, passing `args` as function arguments. /// /// The function's return values are converted to the generic type `R`. @@ -46,7 +166,7 @@ impl<'lua> Function<'lua> { /// /// let tostring: Function = globals.get("tostring")?; /// - /// assert_eq!(tostring.call::<_, String>(123)?, "123"); + /// assert_eq!(tostring.call::(123)?, "123"); /// /// # Ok(()) /// # } @@ -65,89 +185,76 @@ impl<'lua> Function<'lua> { /// end /// "#).eval()?; /// - /// assert_eq!(sum.call::<_, u32>((3, 4))?, 3 + 4); + /// assert_eq!(sum.call::((3, 4))?, 3 + 4); /// /// # Ok(()) /// # } /// ``` - pub fn call, R: FromLuaMulti<'lua>>(&self, args: A) -> Result { - let lua = self.0.lua; - - let mut args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - - let results = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, nargs + 3)?; + pub fn call(&self, args: impl IntoLuaMulti) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 2)?; - ffi::lua_pushcfunction(lua.state, error_traceback); - let stack_start = ffi::lua_gettop(lua.state); + // Push error handler + lua.push_error_traceback(); + let stack_start = ffi::lua_gettop(state); + // Push function and the arguments lua.push_ref(&self.0); - for arg in args.drain_all() { - lua.push_value(arg)?; - } - let ret = ffi::lua_pcall(lua.state, nargs, ffi::LUA_MULTRET, stack_start); + let nargs = args.push_into_stack_multi(&lua)?; + // Call the function + let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start); if ret != ffi::LUA_OK { - return Err(pop_error(lua.state, ret)); - } - let nresults = ffi::lua_gettop(lua.state) - stack_start; - let mut results = args; // Reuse MultiValue container - assert_stack(lua.state, 2); - for _ in 0..nresults { - results.push_front(lua.pop_value()); + return Err(pop_error(state, ret)); } - ffi::lua_pop(lua.state, 1); - results - }; - R::from_lua_multi(results, lua) + // Get the results + let nresults = ffi::lua_gettop(state) - stack_start; + R::from_stack_multi(nresults, &lua) + } } - /// Returns a Feature that, when polled, calls `self`, passing `args` as function arguments, + /// Returns a future that, when polled, calls `self`, passing `args` as function arguments, /// and drives the execution. /// - /// Internally it wraps the function to an [`AsyncThread`]. - /// - /// Requires `feature = "async"` + /// Internally it wraps the function to an [`AsyncThread`]. The returned type implements + /// `Future>` and can be awaited. /// /// # Examples /// /// ``` /// use std::time::Duration; - /// use futures_timer::Delay; /// # use mlua::{Lua, Result}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// # let lua = Lua::new(); /// /// let sleep = lua.create_async_function(move |_lua, n: u64| async move { - /// Delay::new(Duration::from_millis(n)).await; + /// tokio::time::sleep(Duration::from_millis(n)).await; /// Ok(()) /// })?; /// - /// sleep.call_async(10).await?; + /// sleep.call_async::<()>(10).await?; /// /// # Ok(()) /// # } /// ``` /// - /// [`AsyncThread`]: crate::AsyncThread + /// [`AsyncThread`]: crate::thread::AsyncThread #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn call_async<'fut, A, R>(&self, args: A) -> LocalBoxFuture<'fut, Result> + pub fn call_async(&self, args: impl IntoLuaMulti) -> AsyncCallFuture where - 'lua: 'fut, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut, + R: FromLuaMulti, { - let lua = self.0.lua; - match lua.create_recycled_thread(self.clone()) { - Ok(t) => { - let mut t = t.into_async(args); - t.set_recyclable(true); - Box::pin(t) - } - Err(e) => Box::pin(future::err(e)), - } + let lua = self.0.lua.lock(); + AsyncCallFuture(unsafe { + lua.create_recycled_thread(self).and_then(|th| { + let mut th = th.into_async(args)?; + th.set_recyclable(true); + Ok(th) + }) + }) } /// Returns a function that, when called, calls `self`, passing `args` as the first set of @@ -169,16 +276,16 @@ impl<'lua> Function<'lua> { /// "#).eval()?; /// /// let bound_a = sum.bind(1)?; - /// assert_eq!(bound_a.call::<_, u32>(2)?, 1 + 2); + /// assert_eq!(bound_a.call::(2)?, 1 + 2); /// /// let bound_a_and_b = sum.bind(13)?.bind(57)?; - /// assert_eq!(bound_a_and_b.call::<_, u32>(())?, 13 + 57); + /// assert_eq!(bound_a_and_b.call::(())?, 13 + 57); /// /// # Ok(()) /// # } /// ``` - pub fn bind>(&self, args: A) -> Result> { - unsafe extern "C" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int { + pub fn bind(&self, args: impl IntoLuaMulti) -> Result { + unsafe extern "C-unwind" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int { let nargs = ffi::lua_gettop(state); let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int; ffi::luaL_checkstack(state, nbinds, ptr::null()); @@ -186,35 +293,43 @@ impl<'lua> Function<'lua> { for i in 0..nbinds { ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2)); } - ffi::lua_rotate(state, 1, nbinds); + if nargs > 0 { + ffi::lua_rotate(state, 1, nbinds); + } nargs + nbinds } - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); - let args = args.to_lua_multi(lua)?; + let args = args.into_lua_multi(lua.lua())?; let nargs = args.len() as c_int; + if nargs == 0 { + return Ok(self.clone()); + } + if nargs + 1 > ffi::LUA_MAX_UPVALUES { return Err(Error::BindError); } let args_wrapper = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, nargs + 3)?; + let _sg = StackGuard::new(state); + check_stack(state, nargs + 3)?; - ffi::lua_pushinteger(lua.state, nargs as ffi::lua_Integer); - for arg in args { + ffi::lua_pushinteger(state, nargs as ffi::lua_Integer); + for arg in &args { lua.push_value(arg)?; } - protect_lua!(lua.state, nargs + 1, 1, fn(state) { + protect_lua!(state, nargs + 1, 1, fn(state) { ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state)); })?; Function(lua.pop_ref()) }; + let lua = lua.lua(); lua.load( r#" local func, args_wrapper = ... @@ -224,41 +339,148 @@ impl<'lua> Function<'lua> { "#, ) .try_cache() - .set_name("_mlua_bind")? - .call((self.clone(), args_wrapper)) + .set_name("=__mlua_bind") + .call((self, args_wrapper)) + } + + /// Returns the environment of the Lua function. + /// + /// By default Lua functions shares a global environment. + /// + /// This function always returns `None` for Rust/C functions. + pub fn environment(&self) -> Option
{ + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 1); + + lua.push_ref(&self.0); + if ffi::lua_iscfunction(state, -1) != 0 { + return None; + } + + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_getfenv(state, -1); + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + for i in 1..=255 { + // Traverse upvalues until we find the _ENV one + match ffi::lua_getupvalue(state, -1, i) { + s if s.is_null() => break, + s if std::ffi::CStr::from_ptr(s as _) == c"_ENV" => break, + _ => ffi::lua_pop(state, 1), + } + } + + if ffi::lua_type(state, -1) != ffi::LUA_TTABLE { + return None; + } + Some(Table(lua.pop_ref())) + } + } + + /// Sets the environment of the Lua function. + /// + /// The environment is a table that is used as the global environment for the function. + /// Returns `true` if environment successfully changed, `false` otherwise. + /// + /// This function does nothing for Rust/C functions. + pub fn set_environment(&self, env: Table) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + lua.push_ref(&self.0); + if ffi::lua_iscfunction(state, -1) != 0 { + return Ok(false); + } + + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + { + lua.push_ref(&env.0); + ffi::lua_setfenv(state, -2); + } + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + for i in 1..=255 { + match ffi::lua_getupvalue(state, -1, i) { + s if s.is_null() => return Ok(false), + s if std::ffi::CStr::from_ptr(s as _) == c"_ENV" => { + ffi::lua_pop(state, 1); + // Create an anonymous function with the new environment + let f_with_env = lua + .lua() + .load("return _ENV") + .set_environment(env) + .try_cache() + .into_function()?; + lua.push_ref(&f_with_env.0); + ffi::lua_upvaluejoin(state, -2, i, -1, 1); + break; + } + _ => ffi::lua_pop(state, 1), + } + } + + Ok(true) + } } /// Returns information about the function. /// - /// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function. + /// Corresponds to the `>Snu` (`>Sn` for Luau) what mask for + /// [`lua_getinfo`] when applied to the function. /// /// [`lua_getinfo`]: https://www.lua.org/manual/5.4/manual.html#lua_getinfo pub fn info(&self) -> FunctionInfo { - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 1); + let _sg = StackGuard::new(state); + assert_stack(state, 1); let mut ar: ffi::lua_Debug = mem::zeroed(); lua.push_ref(&self.0); + #[cfg(not(feature = "luau"))] - let res = ffi::lua_getinfo(lua.state, cstr!(">Sn"), &mut ar); + let res = ffi::lua_getinfo(state, cstr!(">Snu"), &mut ar); + #[cfg(not(feature = "luau"))] + mlua_assert!(res != 0, "lua_getinfo failed with `>Snu`"); + + #[cfg(feature = "luau")] + let res = ffi::lua_getinfo(state, -1, cstr!("snau"), &mut ar); #[cfg(feature = "luau")] - let res = ffi::lua_getinfo(lua.state, -1, cstr!("sn"), &mut ar); - mlua_assert!(res != 0, "lua_getinfo failed with `>Sn`"); + mlua_assert!(res != 0, "lua_getinfo failed with `snau`"); FunctionInfo { - name: ptr_to_cstr_bytes(ar.name).map(|s| s.to_vec()), + name: ptr_to_lossy_str(ar.name).map(|s| s.into_owned()), #[cfg(not(feature = "luau"))] - name_what: ptr_to_cstr_bytes(ar.namewhat).map(|s| s.to_vec()), + name_what: match ptr_to_str(ar.namewhat) { + Some("") => None, + val => val, + }, #[cfg(feature = "luau")] name_what: None, - what: ptr_to_cstr_bytes(ar.what).map(|s| s.to_vec()), - source: ptr_to_cstr_bytes(ar.source).map(|s| s.to_vec()), - short_src: ptr_to_cstr_bytes(&ar.short_src as *const _).map(|s| s.to_vec()), - line_defined: ar.linedefined as i32, + what: ptr_to_str(ar.what).unwrap_or("main"), + source: ptr_to_lossy_str(ar.source).map(|s| s.into_owned()), + #[cfg(not(feature = "luau"))] + short_src: ptr_to_lossy_str(ar.short_src.as_ptr()).map(|s| s.into_owned()), + #[cfg(feature = "luau")] + short_src: ptr_to_lossy_str(ar.short_src).map(|s| s.into_owned()), + line_defined: linenumber_to_usize(ar.linedefined), #[cfg(not(feature = "luau"))] - last_line_defined: ar.lastlinedefined as i32, + last_line_defined: linenumber_to_usize(ar.lastlinedefined), + #[cfg(feature = "luau")] + last_line_defined: None, + #[cfg(not(feature = "luau"))] + num_upvalues: ar.nups as _, + #[cfg(feature = "luau")] + num_upvalues: ar.nupvals, + #[cfg(not(any(feature = "lua51", feature = "luajit")))] + num_params: ar.nparams, + #[cfg(not(any(feature = "lua51", feature = "luajit")))] + is_vararg: ar.isvararg != 0, } } } @@ -268,46 +490,401 @@ impl<'lua> Function<'lua> { /// If `strip` is true, the binary representation may not include all debug information /// about the function, to save space. /// - /// For Luau a [Compiler] can be used to compile Lua chunks to bytecode. + /// For Luau a [`Compiler`] can be used to compile Lua chunks to bytecode. /// - /// [Compiler]: crate::chunk::Compiler + /// [`Compiler`]: crate::chunk::Compiler #[cfg(not(feature = "luau"))] #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn dump(&self, strip: bool) -> Vec { - use std::os::raw::c_void; - use std::slice; - - unsafe extern "C" fn writer( + unsafe extern "C-unwind" fn writer( _state: *mut ffi::lua_State, buf: *const c_void, buf_len: usize, - data: *mut c_void, + data_ptr: *mut c_void, ) -> c_int { - let data = &mut *(data as *mut Vec); - let buf = slice::from_raw_parts(buf as *const u8, buf_len); - data.extend_from_slice(buf); + // If `data` is null, then it's a signal that write is finished. + if !data_ptr.is_null() && buf_len > 0 { + let data = &mut *(data_ptr as *mut Vec); + let buf = slice::from_raw_parts(buf as *const u8, buf_len); + data.extend_from_slice(buf); + } 0 } - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); let mut data: Vec = Vec::new(); unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 1); + let _sg = StackGuard::new(state); + assert_stack(state, 1); lua.push_ref(&self.0); let data_ptr = &mut data as *mut Vec as *mut c_void; - let strip = if strip { 1 } else { 0 }; - ffi::lua_dump(lua.state, writer, data_ptr, strip); - ffi::lua_pop(lua.state, 1); + ffi::lua_dump(state, writer, data_ptr, strip as i32); + ffi::lua_pop(state, 1); } data } + + /// Retrieves recorded coverage information about this Lua function including inner calls. + /// + /// This function takes a callback as an argument and calls it providing [`CoverageInfo`] + /// snapshot per each executed inner function. + /// + /// Recording of coverage information is controlled by [`Compiler::set_coverage_level`] option. + /// + /// [`Compiler::set_coverage_level`]: crate::chunk::Compiler::set_coverage_level + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn coverage(&self, func: F) + where + F: FnMut(CoverageInfo), + { + use std::ffi::CStr; + use std::os::raw::c_char; + + unsafe extern "C-unwind" fn callback( + data: *mut c_void, + function: *const c_char, + line_defined: c_int, + depth: c_int, + hits: *const c_int, + size: usize, + ) { + let function = if !function.is_null() { + Some(CStr::from_ptr(function).to_string_lossy().to_string()) + } else { + None + }; + let rust_callback = &*(data as *const RefCell); + if let Ok(mut rust_callback) = rust_callback.try_borrow_mut() { + // Call the Rust callback with CoverageInfo + rust_callback(CoverageInfo { + function, + line_defined, + depth, + hits: slice::from_raw_parts(hits, size).to_vec(), + }); + } + } + + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 1); + + lua.push_ref(&self.0); + let func = RefCell::new(func); + let func_ptr = &func as *const RefCell as *mut c_void; + ffi::lua_getcoverage(state, -1, func_ptr, callback::); + } + } + + /// Converts this function to a generic C pointer. + /// + /// There is no way to convert the pointer back to its original value. + /// + /// Typically this function is used only for hashing and debug information. + #[inline] + pub fn to_pointer(&self) -> *const c_void { + self.0.to_pointer() + } + + /// Creates a deep clone of the Lua function. + /// + /// Copies the function prototype and all its upvalues to the + /// newly created function. + /// This function returns shallow clone (same handle) for Rust/C functions. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn deep_clone(&self) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + lua.push_ref(&self.0); + if ffi::lua_iscfunction(state, -1) != 0 { + return Ok(self.clone()); + } + + if lua.unlikely_memory_error() { + ffi::lua_clonefunction(state, -1); + } else { + protect_lua!(state, 1, 1, fn(state) ffi::lua_clonefunction(state, -1))?; + } + Ok(Function(lua.pop_ref())) + } + } } -impl<'lua> PartialEq for Function<'lua> { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 +struct WrappedFunction(pub(crate) Callback); + +#[cfg(feature = "async")] +struct WrappedAsyncFunction(pub(crate) AsyncCallback); + +impl Function { + /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] + /// trait. + #[inline] + pub fn wrap(func: F) -> impl IntoLua + where + F: LuaNativeFn> + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + E: ExternalError, + { + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).into_lua_err()?.push_into_stack_multi(lua) + })) + } + + /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait. + pub fn wrap_mut(func: F) -> impl IntoLua + where + F: LuaNativeFnMut> + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + E: ExternalError, + { + let func = RefCell::new(func); + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).into_lua_err()?.push_into_stack_multi(lua) + })) + } + + /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] + /// trait. + /// + /// This function is similar to [`Function::wrap`] but any returned `Result` will be converted + /// to a `ok, err` tuple without throwing an exception. + #[inline] + pub fn wrap_raw(func: F) -> impl IntoLua + where + F: LuaNativeFn + MaybeSend + 'static, + F::Output: IntoLuaMulti, + A: FromLuaMulti, + { + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).push_into_stack_multi(lua) + })) + } + + /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait. + /// + /// This function is similar to [`Function::wrap_mut`] but any returned `Result` will be + /// converted to a `ok, err` tuple without throwing an exception. + #[inline] + pub fn wrap_raw_mut(func: F) -> impl IntoLua + where + F: LuaNativeFnMut + MaybeSend + 'static, + F::Output: IntoLuaMulti, + A: FromLuaMulti, + { + let func = RefCell::new(func); + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).push_into_stack_multi(lua) + })) + } + + /// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`] + /// trait. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn wrap_async(func: F) -> impl IntoLua + where + F: LuaNativeAsyncFn> + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + E: ExternalError, + { + WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe { + let args = match A::from_stack_args(nargs, 1, None, rawlua) { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = func.call(args); + Box::pin(async move { fut.await.into_lua_err()?.push_into_stack_multi(lua.raw_lua()) }) + })) + } + + /// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`] + /// trait. + /// + /// This function is similar to [`Function::wrap_async`] but any returned `Result` will be + /// converted to a `ok, err` tuple without throwing an exception. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn wrap_raw_async(func: F) -> impl IntoLua + where + F: LuaNativeAsyncFn + MaybeSend + 'static, + F::Output: IntoLuaMulti, + A: FromLuaMulti, + { + WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe { + let args = match A::from_stack_args(nargs, 1, None, rawlua) { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = func.call(args); + Box::pin(async move { fut.await.push_into_stack_multi(lua.raw_lua()) }) + })) + } +} + +impl IntoLua for WrappedFunction { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + lua.lock().create_callback(self.0).map(Value::Function) + } +} + +#[cfg(feature = "async")] +impl IntoLua for WrappedAsyncFunction { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + lua.lock().create_async_callback(self.0).map(Value::Function) } } + +impl LuaType for Function { + const TYPE_ID: c_int = ffi::LUA_TFUNCTION; +} + +/// Future for asynchronous function calls. +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct AsyncCallFuture(Result>); + +#[cfg(feature = "async")] +impl AsyncCallFuture { + pub(crate) fn error(err: Error) -> Self { + AsyncCallFuture(Err(err)) + } +} + +#[cfg(feature = "async")] +impl Future for AsyncCallFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + match &mut this.0 { + Ok(thread) => pin!(thread).poll(cx), + Err(err) => Poll::Ready(Err(err.clone())), + } + } +} + +/// A trait for types that can be used as Lua functions. +pub trait LuaNativeFn { + type Output; + + fn call(&self, args: A) -> Self::Output; +} + +/// A trait for types with mutable state that can be used as Lua functions. +pub trait LuaNativeFnMut { + type Output; + + fn call(&mut self, args: A) -> Self::Output; +} + +/// A trait for types that returns a future and can be used as Lua functions. +#[cfg(feature = "async")] +pub trait LuaNativeAsyncFn { + type Output; + + fn call(&self, args: A) -> impl Future + MaybeSend + 'static; +} + +macro_rules! impl_lua_native_fn { + ($($A:ident),*) => { + impl LuaNativeFn<($($A,)*)> for FN + where + FN: Fn($($A,)*) -> R + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&self, args: ($($A,)*)) -> Self::Output { + let ($($A,)*) = args; + self($($A,)*) + } + } + + impl LuaNativeFnMut<($($A,)*)> for FN + where + FN: FnMut($($A,)*) -> R + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&mut self, args: ($($A,)*)) -> Self::Output { + let ($($A,)*) = args; + self($($A,)*) + } + } + + #[cfg(feature = "async")] + impl LuaNativeAsyncFn<($($A,)*)> for FN + where + FN: Fn($($A,)*) -> Fut + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + Fut: Future + MaybeSend + 'static, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&self, args: ($($A,)*)) -> impl Future + MaybeSend + 'static { + let ($($A,)*) = args; + self($($A,)*) + } + } + }; +} + +impl_lua_native_fn!(); +impl_lua_native_fn!(A); +impl_lua_native_fn!(A, B); +impl_lua_native_fn!(A, B, C); +impl_lua_native_fn!(A, B, C, D); +impl_lua_native_fn!(A, B, C, D, E); +impl_lua_native_fn!(A, B, C, D, E, F); +impl_lua_native_fn!(A, B, C, D, E, F, G); +impl_lua_native_fn!(A, B, C, D, E, F, G, H); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P); + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(Function: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(Function: Send, Sync); + + #[cfg(all(feature = "async", feature = "send"))] + static_assertions::assert_impl_all!(AsyncCallFuture<()>: Send); +} diff --git a/src/hook.rs b/src/hook.rs deleted file mode 100644 index c9e37144..00000000 --- a/src/hook.rs +++ /dev/null @@ -1,351 +0,0 @@ -use std::cell::UnsafeCell; -#[cfg(not(feature = "luau"))] -use std::ops::{BitOr, BitOrAssign}; -use std::os::raw::c_int; - -use crate::ffi::{self, lua_Debug}; -use crate::lua::Lua; -use crate::util::ptr_to_cstr_bytes; - -/// Contains information about currently executing Lua code. -/// -/// The `Debug` structure is provided as a parameter to the hook function set with -/// [`Lua::set_hook`]. You may call the methods on this structure to retrieve information about the -/// Lua code executing at the time that the hook function was called. Further information can be -/// found in the Lua [documentation][lua_doc]. -/// -/// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#lua_Debug -/// [`Lua::set_hook`]: crate::Lua::set_hook -pub struct Debug<'lua> { - lua: &'lua Lua, - ar: ActivationRecord, - #[cfg(feature = "luau")] - level: c_int, -} - -impl<'lua> Debug<'lua> { - #[cfg(not(feature = "luau"))] - pub(crate) fn new(lua: &'lua Lua, ar: *mut lua_Debug) -> Self { - Debug { - lua, - ar: ActivationRecord::Borrowed(ar), - } - } - - pub(crate) fn new_owned(lua: &'lua Lua, _level: c_int, ar: lua_Debug) -> Self { - Debug { - lua, - ar: ActivationRecord::Owned(UnsafeCell::new(ar)), - #[cfg(feature = "luau")] - level: _level, - } - } - - /// Returns the specific event that triggered the hook. - /// - /// For [Lua 5.1] `DebugEvent::TailCall` is used for return events to indicate a return - /// from a function that did a tail call. - /// - /// [Lua 5.1]: https://www.lua.org/manual/5.1/manual.html#pdf-LUA_HOOKTAILRET - #[cfg(not(feature = "luau"))] - #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] - pub fn event(&self) -> DebugEvent { - unsafe { - match (*self.ar.get()).event { - ffi::LUA_HOOKCALL => DebugEvent::Call, - ffi::LUA_HOOKRET => DebugEvent::Ret, - ffi::LUA_HOOKTAILCALL => DebugEvent::TailCall, - ffi::LUA_HOOKLINE => DebugEvent::Line, - ffi::LUA_HOOKCOUNT => DebugEvent::Count, - event => DebugEvent::Unknown(event), - } - } - } - - /// Corresponds to the `n` what mask. - pub fn names(&self) -> DebugNames<'lua> { - unsafe { - #[cfg(not(feature = "luau"))] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, cstr!("n"), self.ar.get()) != 0, - "lua_getinfo failed with `n`" - ); - #[cfg(feature = "luau")] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, self.level, cstr!("n"), self.ar.get()) != 0, - "lua_getinfo failed with `n`" - ); - - DebugNames { - name: ptr_to_cstr_bytes((*self.ar.get()).name), - #[cfg(not(feature = "luau"))] - name_what: ptr_to_cstr_bytes((*self.ar.get()).namewhat), - #[cfg(feature = "luau")] - name_what: None, - } - } - } - - /// Corresponds to the `S` what mask. - pub fn source(&self) -> DebugSource<'lua> { - unsafe { - #[cfg(not(feature = "luau"))] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, cstr!("S"), self.ar.get()) != 0, - "lua_getinfo failed with `S`" - ); - #[cfg(feature = "luau")] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, self.level, cstr!("s"), self.ar.get()) != 0, - "lua_getinfo failed with `s`" - ); - - DebugSource { - source: ptr_to_cstr_bytes((*self.ar.get()).source), - short_src: ptr_to_cstr_bytes((*self.ar.get()).short_src.as_ptr()), - line_defined: (*self.ar.get()).linedefined as i32, - #[cfg(not(feature = "luau"))] - last_line_defined: (*self.ar.get()).lastlinedefined as i32, - what: ptr_to_cstr_bytes((*self.ar.get()).what), - } - } - } - - /// Corresponds to the `l` what mask. Returns the current line. - pub fn curr_line(&self) -> i32 { - unsafe { - #[cfg(not(feature = "luau"))] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, cstr!("l"), self.ar.get()) != 0, - "lua_getinfo failed with `l`" - ); - #[cfg(feature = "luau")] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, self.level, cstr!("l"), self.ar.get()) != 0, - "lua_getinfo failed with `l`" - ); - - (*self.ar.get()).currentline as i32 - } - } - - /// Corresponds to the `t` what mask. Returns true if the hook is in a function tail call, false - /// otherwise. - #[cfg(not(feature = "luau"))] - #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] - pub fn is_tail_call(&self) -> bool { - unsafe { - mlua_assert!( - ffi::lua_getinfo(self.lua.state, cstr!("t"), self.ar.get()) != 0, - "lua_getinfo failed with `t`" - ); - (*self.ar.get()).currentline != 0 - } - } - - /// Corresponds to the `u` what mask. - pub fn stack(&self) -> DebugStack { - unsafe { - #[cfg(not(feature = "luau"))] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, cstr!("u"), self.ar.get()) != 0, - "lua_getinfo failed with `u`" - ); - #[cfg(feature = "luau")] - mlua_assert!( - ffi::lua_getinfo(self.lua.state, self.level, cstr!("a"), self.ar.get()) != 0, - "lua_getinfo failed with `a`" - ); - - #[cfg(not(feature = "luau"))] - let stack = DebugStack { - num_ups: (*self.ar.get()).nups as i32, - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - num_params: (*self.ar.get()).nparams as i32, - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - is_vararg: (*self.ar.get()).isvararg != 0, - }; - #[cfg(feature = "luau")] - let stack = DebugStack { - num_ups: (*self.ar.get()).nupvals as i32, - num_params: (*self.ar.get()).nparams as i32, - is_vararg: (*self.ar.get()).isvararg != 0, - }; - stack - } - } -} - -enum ActivationRecord { - #[cfg(not(feature = "luau"))] - Borrowed(*mut lua_Debug), - Owned(UnsafeCell), -} - -impl ActivationRecord { - #[inline] - fn get(&self) -> *mut lua_Debug { - match self { - #[cfg(not(feature = "luau"))] - ActivationRecord::Borrowed(x) => *x, - ActivationRecord::Owned(x) => x.get(), - } - } -} - -/// Represents a specific event that triggered the hook. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum DebugEvent { - Call, - Ret, - TailCall, - Line, - Count, - Unknown(c_int), -} - -#[derive(Clone, Debug)] -pub struct DebugNames<'a> { - pub name: Option<&'a [u8]>, - pub name_what: Option<&'a [u8]>, -} - -#[derive(Clone, Debug)] -pub struct DebugSource<'a> { - pub source: Option<&'a [u8]>, - pub short_src: Option<&'a [u8]>, - pub line_defined: i32, - #[cfg(not(feature = "luau"))] - pub last_line_defined: i32, - pub what: Option<&'a [u8]>, -} - -#[derive(Copy, Clone, Debug)] -pub struct DebugStack { - pub num_ups: i32, - /// Requires `feature = "lua54/lua53/lua52/luau"` - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luau" - ))] - pub num_params: i32, - /// Requires `feature = "lua54/lua53/lua52/luau"` - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luau" - ))] - pub is_vararg: bool, -} - -/// Determines when a hook function will be called by Lua. -#[cfg(not(feature = "luau"))] -#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] -#[derive(Clone, Copy, Debug, Default)] -pub struct HookTriggers { - /// Before a function call. - pub on_calls: bool, - /// When Lua returns from a function. - pub on_returns: bool, - /// Before executing a new line, or returning from a function call. - pub every_line: bool, - /// After a certain number of VM instructions have been executed. When set to `Some(count)`, - /// `count` is the number of VM instructions to execute before calling the hook. - /// - /// # Performance - /// - /// Setting this option to a low value can incur a very high overhead. - pub every_nth_instruction: Option, -} - -#[cfg(not(feature = "luau"))] -impl HookTriggers { - /// Returns a new instance of `HookTriggers` with [`on_calls`] trigger set. - /// - /// [`on_calls`]: #structfield.on_calls - pub fn on_calls() -> Self { - HookTriggers { - on_calls: true, - ..Default::default() - } - } - - /// Returns a new instance of `HookTriggers` with [`on_returns`] trigger set. - /// - /// [`on_returns`]: #structfield.on_returns - pub fn on_returns() -> Self { - HookTriggers { - on_returns: true, - ..Default::default() - } - } - - /// Returns a new instance of `HookTriggers` with [`every_line`] trigger set. - /// - /// [`every_line`]: #structfield.every_line - pub fn every_line() -> Self { - HookTriggers { - every_line: true, - ..Default::default() - } - } - - /// Returns a new instance of `HookTriggers` with [`every_nth_instruction`] trigger set. - /// - /// [`every_nth_instruction`]: #structfield.every_nth_instruction - pub fn every_nth_instruction(n: u32) -> Self { - HookTriggers { - every_nth_instruction: Some(n), - ..Default::default() - } - } - - // Compute the mask to pass to `lua_sethook`. - pub(crate) fn mask(&self) -> c_int { - let mut mask: c_int = 0; - if self.on_calls { - mask |= ffi::LUA_MASKCALL - } - if self.on_returns { - mask |= ffi::LUA_MASKRET - } - if self.every_line { - mask |= ffi::LUA_MASKLINE - } - if self.every_nth_instruction.is_some() { - mask |= ffi::LUA_MASKCOUNT - } - mask - } - - // Returns the `count` parameter to pass to `lua_sethook`, if applicable. Otherwise, zero is - // returned. - pub(crate) fn count(&self) -> c_int { - self.every_nth_instruction.unwrap_or(0) as c_int - } -} - -#[cfg(not(feature = "luau"))] -impl BitOr for HookTriggers { - type Output = Self; - - fn bitor(mut self, rhs: Self) -> Self::Output { - self.on_calls |= rhs.on_calls; - self.on_returns |= rhs.on_returns; - self.every_line |= rhs.every_line; - if self.every_nth_instruction.is_none() && rhs.every_nth_instruction.is_some() { - self.every_nth_instruction = rhs.every_nth_instruction; - } - self - } -} - -#[cfg(not(feature = "luau"))] -impl BitOrAssign for HookTriggers { - fn bitor_assign(&mut self, rhs: Self) { - *self = *self | rhs; - } -} diff --git a/src/lib.rs b/src/lib.rs index 0ccb645f..1a00dba3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,10 +10,10 @@ //! //! # Converting data //! -//! The [`ToLua`] and [`FromLua`] traits allow conversion from Rust types to Lua values and vice +//! The [`IntoLua`] and [`FromLua`] traits allow conversion from Rust types to Lua values and vice //! versa. They are implemented for many data structures found in Rust's standard library. //! -//! For more general conversions, the [`ToLuaMulti`] and [`FromLuaMulti`] traits allow converting +//! For more general conversions, the [`IntoLuaMulti`] and [`FromLuaMulti`] traits allow converting //! between Rust types and *any number* of Lua values. //! //! Most code in `mlua` is generic over implementors of those traits, so in most places the normal @@ -27,123 +27,145 @@ //! //! # Serde support //! -//! The [`LuaSerdeExt`] trait implemented for [`Lua`] allows conversion from Rust types to Lua values -//! and vice versa using serde. Any user defined data type that implements [`serde::Serialize`] or -//! [`serde::Deserialize`] can be converted. +//! The [`LuaSerdeExt`] trait implemented for [`Lua`] allows conversion from Rust types to Lua +//! values and vice versa using serde. Any user defined data type that implements +//! [`serde::Serialize`] or [`serde::Deserialize`] can be converted. //! For convenience, additional functionality to handle `NULL` values and arrays is provided. //! -//! The [`Value`] enum implements [`serde::Serialize`] trait to support serializing Lua values -//! (including [`UserData`]) into Rust values. +//! The [`Value`] enum and other types implement [`serde::Serialize`] trait to support serializing +//! Lua values into Rust values. //! -//! Requires `feature = "serialize"`. +//! Requires `feature = "serde"`. //! //! # Async/await support //! -//! The [`create_async_function`] allows creating non-blocking functions that returns [`Future`]. -//! Lua code with async capabilities can be executed by [`call_async`] family of functions or polling -//! [`AsyncThread`] using any runtime (eg. Tokio). +//! The [`Lua::create_async_function`] allows creating non-blocking functions that returns +//! [`Future`]. Lua code with async capabilities can be executed by [`Function::call_async`] family +//! of functions or polling [`AsyncThread`] using any runtime (eg. Tokio). //! //! Requires `feature = "async"`. //! -//! # `Send` requirement -//! By default `mlua` is `!Send`. This can be changed by enabling `feature = "send"` that adds `Send` requirement -//! to [`Function`]s and [`UserData`]. +//! # `Send` and `Sync` support +//! +//! By default `mlua` is `!Send`. This can be changed by enabling `feature = "send"` that adds +//! `Send` requirement to Rust functions and [`UserData`] types. +//! +//! In this case [`Lua`] object and their types can be send or used from other threads. Internally +//! access to Lua VM is synchronized using a reentrant mutex that can be locked many times within +//! the same thread. //! //! [Lua programming language]: https://www.lua.org/ -//! [`Lua`]: crate::Lua //! [executing]: crate::Chunk::exec //! [evaluating]: crate::Chunk::eval //! [globals]: crate::Lua::globals -//! [`ToLua`]: crate::ToLua -//! [`FromLua`]: crate::FromLua -//! [`ToLuaMulti`]: crate::ToLuaMulti -//! [`FromLuaMulti`]: crate::FromLuaMulti -//! [`Function`]: crate::Function -//! [`UserData`]: crate::UserData -//! [`UserDataFields`]: crate::UserDataFields -//! [`UserDataMethods`]: crate::UserDataMethods -//! [`LuaSerdeExt`]: crate::LuaSerdeExt -//! [`Value`]: crate::Value -//! [`create_async_function`]: crate::Lua::create_async_function -//! [`call_async`]: crate::Function::call_async -//! [`AsyncThread`]: crate::AsyncThread //! [`Future`]: std::future::Future //! [`serde::Serialize`]: https://docs.serde.rs/serde/ser/trait.Serialize.html //! [`serde::Deserialize`]: https://docs.serde.rs/serde/de/trait.Deserialize.html +//! [`AsyncThread`]: crate::thread::AsyncThread -// mlua types in rustdoc of other crates get linked to here. -#![doc(html_root_url = "https://docs.rs/mlua/0.8.0-beta.4")] // Deny warnings inside doc tests / examples. When this isn't present, rustdoc doesn't show *any* // warnings at all. -#![doc(test(attr(deny(warnings))))] #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(not(send), allow(clippy::arc_with_non_send_sync))] +#![allow(unsafe_op_in_unsafe_fn)] #[macro_use] mod macros; -mod chunk; +mod buffer; mod conversion; -mod error; -mod ffi; -mod function; -mod hook; -mod lua; -#[cfg(feature = "luau")] -mod luau; +mod memory; mod multi; mod scope; mod stdlib; -mod string; -mod table; -mod thread; +mod traits; mod types; -mod userdata; -mod userdata_impl; mod util; mod value; +mod vector; +pub mod chunk; +pub mod debug; +pub mod error; +pub mod function; +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +pub mod luau; pub mod prelude; +pub mod state; +pub mod string; +pub mod table; +pub mod thread; +pub mod userdata; -pub use crate::{ffi::lua_CFunction, ffi::lua_State}; +pub use bstr::BString; +pub use ffi::{self, lua_CFunction, lua_State}; -pub use crate::chunk::{AsChunk, Chunk, ChunkMode}; -pub use crate::error::{Error, ExternalError, ExternalResult, Result}; -pub use crate::function::{Function, FunctionInfo}; -pub use crate::hook::{Debug, DebugEvent, DebugNames, DebugSource, DebugStack}; -pub use crate::lua::{GCMode, Lua, LuaOptions}; -pub use crate::multi::Variadic; +#[doc(inline)] +pub use crate::error::{Error, Result}; +#[doc(inline)] +pub use crate::function::Function; +pub use crate::multi::{MultiValue, Variadic}; pub use crate::scope::Scope; +#[doc(inline)] +pub use crate::state::{Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; -pub use crate::string::String; -pub use crate::table::{Table, TableExt, TablePairs, TableSequence}; -pub use crate::thread::{Thread, ThreadStatus}; -pub use crate::types::{Integer, LightUserData, Number, RegistryKey}; +#[doc(inline)] +pub use crate::string::{BorrowedBytes, BorrowedStr, LuaString}; +#[doc(inline)] +pub use crate::table::Table; +#[doc(inline)] +pub use crate::thread::Thread; +#[doc(inline)] +pub use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, ObjectLike}; +pub use crate::types::{ + AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, MaybeSync, Number, RegistryKey, + VmState, +}; +#[doc(inline)] +pub use crate::userdata::AnyUserData; +pub use crate::value::{Nil, Value}; + +// Re-export some types to keep backward compatibility and avoid breaking changes in the public API. +#[doc(hidden)] +pub use crate::chunk::{AsChunk, Chunk, ChunkMode}; +#[cfg(feature = "luau")] +#[doc(hidden)] +pub use crate::chunk::{CompileConstant, Compiler}; +#[doc(hidden)] +pub use crate::error::{ErrorContext, ExternalError, ExternalResult}; +#[doc(hidden)] +pub use crate::string::LuaString as String; +#[doc(hidden)] +pub use crate::table::{TablePairs, TableSequence}; +#[doc(hidden)] +pub use crate::thread::ThreadStatus; +#[doc(hidden)] pub use crate::userdata::{ - AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, + MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataOwned, UserDataRef, + UserDataRefMut, UserDataRegistry, }; -pub use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; #[cfg(not(feature = "luau"))] -pub use crate::hook::HookTriggers; +#[doc(inline)] +pub use crate::debug::HookTriggers; #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] -pub use crate::{chunk::Compiler, types::VmState}; +pub use crate::{buffer::Buffer, vector::Vector}; -#[cfg(feature = "async")] -pub use crate::thread::AsyncThread; - -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] +#[doc(hidden)] +pub use crate::serde::{DeserializeOptions, SerializeOptions}; +#[cfg(feature = "serde")] #[doc(inline)] -pub use crate::serde::{ - de::Options as DeserializeOptions, ser::Options as SerializeOptions, LuaSerdeExt, -}; +pub use crate::{serde::LuaSerdeExt, value::SerializableValue}; -#[cfg(feature = "serialize")] -#[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] +#[cfg(feature = "serde")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] pub mod serde; -#[cfg(any(feature = "mlua_derive"))] +#[cfg(feature = "mlua_derive")] #[allow(unused_imports)] #[macro_use] extern crate mlua_derive; @@ -153,7 +175,7 @@ extern crate mlua_derive; /// This macro allows to write Lua code directly in Rust code. /// /// Rust variables can be referenced from Lua using `$` prefix, as shown in the example below. -/// User's Rust types needs to implement [`UserData`] or [`ToLua`] traits. +/// User's Rust types needs to implement [`UserData`] or [`IntoLua`] traits. /// /// Captured variables are **moved** into the chunk. /// @@ -188,27 +210,31 @@ extern crate mlua_derive; /// /// Other minor limitations: /// -/// - Certain escape codes in string literals don't work. -/// (Specifically: `\a`, `\b`, `\f`, `\v`, `\123` (octal escape codes), `\u`, and `\U`). +/// - Certain escape codes in string literals don't work. (Specifically: `\a`, `\b`, `\f`, `\v`, +/// `\123` (octal escape codes), `\u`, and `\U`). /// /// These are accepted: : `\\`, `\n`, `\t`, `\r`, `\xAB` (hex escape codes), and `\0`. /// /// - The `//` (floor division) operator is unusable, as its start a comment. /// /// Everything else should work. -/// -/// [`AsChunk`]: crate::AsChunk -/// [`UserData`]: crate::UserData -/// [`ToLua`]: crate::ToLua -#[cfg(any(feature = "macros"))] +#[cfg(feature = "macros")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] pub use mlua_derive::chunk; +/// Derive [`FromLua`] for a Rust type. +/// +/// Current implementation generate code that takes [`UserData`] value, borrow it (of the Rust type) +/// and clone. +#[cfg(feature = "macros")] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +pub use mlua_derive::FromLua; + /// Registers Lua module entrypoint. /// /// You can register multiple entrypoints as required. /// -/// ``` +/// ```ignore /// use mlua::{Lua, Result, Table}; /// /// #[mlua::lua_module] @@ -221,6 +247,45 @@ pub use mlua_derive::chunk; /// /// Internally in the code above the compiler defines C function `luaopen_my_module`. /// -#[cfg(any(feature = "module", docsrs))] +/// You can also pass options to the attribute: +/// +/// * name - name of the module, defaults to the name of the function +/// +/// ```ignore +/// #[mlua::lua_module(name = "alt_module")] +/// fn my_module(lua: &Lua) -> Result
{ +/// ... +/// } +/// ``` +/// +/// * skip_memory_check - skip memory allocation checks for some operations. +/// +/// In module mode, mlua runs in an unknown environment and cannot tell whether there are any memory +/// limits or not. As a result, some operations that require memory allocation run in protected +/// mode. Setting this attribute will improve performance of such operations with risk of having +/// uncaught exceptions and memory leaks. +/// +/// ```ignore +/// #[mlua::lua_module(skip_memory_check)] +/// fn my_module(lua: &Lua) -> Result
{ +/// ... +/// } +/// ``` +#[cfg(all(feature = "mlua_derive", any(feature = "module", doc)))] #[cfg_attr(docsrs, doc(cfg(feature = "module")))] pub use mlua_derive::lua_module; + +#[cfg(all(feature = "module", feature = "send"))] +compile_error!("`send` feature is not supported in module mode"); + +pub(crate) mod private { + use super::*; + + pub trait Sealed {} + + impl Sealed for Error {} + impl Sealed for std::result::Result {} + impl Sealed for Lua {} + impl Sealed for Table {} + impl Sealed for AnyUserData {} +} diff --git a/src/lua.rs b/src/lua.rs deleted file mode 100644 index 5aac6c5e..00000000 --- a/src/lua.rs +++ /dev/null @@ -1,3176 +0,0 @@ -use std::any::{Any, TypeId}; -use std::cell::{Ref, RefCell, RefMut, UnsafeCell}; -use std::collections::HashMap; -use std::ffi::{CStr, CString}; -use std::fmt; -use std::marker::PhantomData; -use std::mem::ManuallyDrop; -use std::ops::{Deref, DerefMut}; -use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location}; -use std::sync::{Arc, Mutex}; -use std::{mem, ptr, str}; - -use rustc_hash::FxHashMap; - -use crate::chunk::{AsChunk, Chunk, ChunkMode}; -use crate::error::{Error, Result}; -use crate::ffi; -use crate::function::Function; -use crate::hook::Debug; -use crate::scope::Scope; -use crate::stdlib::StdLib; -use crate::string::String; -use crate::table::Table; -use crate::thread::Thread; -use crate::types::{ - Callback, CallbackUpvalue, DestructedUserdataMT, Integer, LightUserData, LuaRef, MaybeSend, - Number, RegistryKey, -}; -use crate::userdata::{AnyUserData, UserData, UserDataCell}; -use crate::userdata_impl::{StaticUserDataFields, StaticUserDataMethods}; -use crate::util::{ - self, assert_stack, callback_error, check_stack, get_destructed_userdata_metatable, - get_gc_metatable, get_gc_userdata, get_main_state, get_userdata, init_error_registry, - init_gc_metatable, init_userdata_metatable, pop_error, push_gc_userdata, push_string, - push_table, rawset_field, safe_pcall, safe_xpcall, StackGuard, WrappedFailure, -}; -use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti, Value}; - -#[cfg(not(feature = "lua54"))] -use crate::util::push_userdata; -#[cfg(feature = "lua54")] -use crate::{types::WarnCallback, userdata::USER_VALUE_MAXSLOT, util::push_userdata_uv}; - -#[cfg(not(feature = "luau"))] -use crate::{hook::HookTriggers, types::HookCallback}; - -#[cfg(feature = "luau")] -use crate::types::InterruptCallback; -#[cfg(any(feature = "luau", doc))] -use crate::{chunk::Compiler, types::VmState}; - -#[cfg(feature = "async")] -use { - crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}, - futures_core::{ - future::{Future, LocalBoxFuture}, - task::{Context, Poll, Waker}, - }, - futures_task::noop_waker, - futures_util::future::{self, TryFutureExt}, -}; - -#[cfg(feature = "serialize")] -use serde::Serialize; - -/// Top level Lua struct which represents an instance of Lua VM. -#[repr(transparent)] -pub struct Lua(Arc>); - -/// An inner Lua struct which holds a raw Lua state. -pub struct LuaInner { - pub(crate) state: *mut ffi::lua_State, - main_state: *mut ffi::lua_State, - extra: Arc>, - safe: bool, - #[cfg(feature = "luau")] - compiler: Option, - // Lua has lots of interior mutability, should not be RefUnwindSafe - _no_ref_unwind_safe: PhantomData>, -} - -// Data associated with the Lua. -pub(crate) struct ExtraData { - // Same layout as `Lua` - inner: Option>>>, - - registered_userdata: FxHashMap, - registered_userdata_mt: FxHashMap<*const c_void, Option>, - registry_unref_list: Arc>>>, - - #[cfg(not(feature = "send"))] - app_data: RefCell>>, - #[cfg(feature = "send")] - app_data: RefCell>>, - - libs: StdLib, - mem_info: Option>, - - ref_thread: *mut ffi::lua_State, - ref_stack_size: c_int, - ref_stack_top: c_int, - ref_free: Vec, - - // Cache of `WrappedFailure` enums on the ref thread (as userdata) - wrapped_failures_cache: Vec, - // Cache of recycled `MultiValue` containers - multivalue_cache: Vec>, - // Cache of recycled `Thread`s (coroutines) - #[cfg(feature = "async")] - recycled_thread_cache: Vec, - - // Index of `Option` userdata on the ref thread - #[cfg(feature = "async")] - ref_waker_idx: c_int, - - #[cfg(not(feature = "luau"))] - hook_callback: Option, - #[cfg(feature = "lua54")] - warn_callback: Option, - #[cfg(feature = "luau")] - interrupt_callback: Option, - - #[cfg(feature = "luau")] - sandboxed: bool, -} - -#[cfg_attr(any(feature = "lua51", feature = "luajit"), allow(dead_code))] -struct MemoryInfo { - used_memory: isize, - memory_limit: isize, -} - -/// Mode of the Lua garbage collector (GC). -/// -/// In Lua 5.4 GC can work in two modes: incremental and generational. -/// Previous Lua versions support only incremental GC. -/// -/// More information can be found in the Lua [documentation]. -/// -/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5 -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum GCMode { - Incremental, - /// Requires `feature = "lua54"` - #[cfg(any(feature = "lua54"))] - #[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] - Generational, -} - -/// Controls Lua interpreter behavior such as Rust panics handling. -#[derive(Clone, Debug)] -#[non_exhaustive] -pub struct LuaOptions { - /// Catch Rust panics when using [`pcall`]/[`xpcall`]. - /// - /// If disabled, wraps these functions and automatically resumes panic if found. - /// Also in Lua 5.1 adds ability to provide arguments to [`xpcall`] similar to Lua >= 5.2. - /// - /// If enabled, keeps [`pcall`]/[`xpcall`] unmodified. - /// Panics are still automatically resumed if returned to the Rust side. - /// - /// Default: **true** - /// - /// [`pcall`]: https://www.lua.org/manual/5.4/manual.html#pdf-pcall - /// [`xpcall`]: https://www.lua.org/manual/5.4/manual.html#pdf-xpcall - pub catch_rust_panics: bool, - - /// Max size of thread (coroutine) object cache used to execute asynchronous functions. - /// - /// It works on Lua 5.4, LuaJIT (vendored) and Luau, where [`lua_resetthread`] function - /// is available and allows to reuse old coroutines with reset state. - /// - /// Default: **0** (disabled) - /// - /// [`lua_resetthread`]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub thread_cache_size: usize, -} - -impl Default for LuaOptions { - fn default() -> Self { - LuaOptions::new() - } -} - -impl LuaOptions { - /// Returns a new instance of `LuaOptions` with default parameters. - pub const fn new() -> Self { - LuaOptions { - catch_rust_panics: true, - #[cfg(feature = "async")] - thread_cache_size: 0, - } - } - - /// Sets [`catch_rust_panics`] option. - /// - /// [`catch_rust_panics`]: #structfield.catch_rust_panics - #[must_use] - pub const fn catch_rust_panics(mut self, enabled: bool) -> Self { - self.catch_rust_panics = enabled; - self - } - - /// Sets [`thread_cache_size`] option. - /// - /// [`thread_cache_size`]: #structfield.thread_cache_size - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - #[must_use] - pub const fn thread_cache_size(mut self, size: usize) -> Self { - self.thread_cache_size = size; - self - } -} - -#[cfg(feature = "async")] -pub(crate) static ASYNC_POLL_PENDING: u8 = 0; -pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0; - -const WRAPPED_FAILURES_CACHE_SIZE: usize = 32; -const MULTIVALUE_CACHE_SIZE: usize = 32; - -/// Requires `feature = "send"` -#[cfg(feature = "send")] -#[cfg_attr(docsrs, doc(cfg(feature = "send")))] -unsafe impl Send for Lua {} - -#[cfg(not(feature = "module"))] -impl Drop for LuaInner { - fn drop(&mut self) { - unsafe { - let extra = &mut *self.extra.get(); - let drain_iter = extra.wrapped_failures_cache.drain(..); - #[cfg(feature = "async")] - let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..)); - for index in drain_iter { - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, index); - extra.ref_free.push(index); - } - #[cfg(feature = "async")] - { - // Destroy Waker slot - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx); - extra.ref_free.push(extra.ref_waker_idx); - } - #[cfg(feature = "luau")] - { - let callbacks = ffi::lua_callbacks(self.state); - let extra_ptr = (*callbacks).userdata as *mut Arc>; - drop(Box::from_raw(extra_ptr)); - (*callbacks).userdata = ptr::null_mut(); - } - mlua_debug_assert!( - ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top - && extra.ref_stack_top as usize == extra.ref_free.len(), - "reference leak detected" - ); - ffi::lua_close(self.main_state); - } - } -} - -impl Drop for ExtraData { - fn drop(&mut self) { - #[cfg(feature = "module")] - unsafe { - ManuallyDrop::drop(&mut self.inner.take().unwrap()) - }; - - *mlua_expect!(self.registry_unref_list.lock(), "unref list poisoned") = None; - if let Some(mem_info) = self.mem_info { - drop(unsafe { Box::from_raw(mem_info.as_ptr()) }); - } - } -} - -impl fmt::Debug for Lua { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Lua({:p})", self.state) - } -} - -impl Deref for Lua { - type Target = LuaInner; - - fn deref(&self) -> &Self::Target { - unsafe { &*(*self.0).get() } - } -} - -impl DerefMut for Lua { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *(*self.0).get() } - } -} - -impl Lua { - /// Creates a new Lua state and loads the **safe** subset of the standard libraries. - /// - /// # Safety - /// The created Lua state would have _some_ safety guarantees and would not allow to load unsafe - /// standard libraries or C modules. - /// - /// See [`StdLib`] documentation for a list of unsafe modules that cannot be loaded. - /// - /// [`StdLib`]: crate::StdLib - #[allow(clippy::new_without_default)] - pub fn new() -> Lua { - mlua_expect!( - Self::new_with(StdLib::ALL_SAFE, LuaOptions::default()), - "can't create new safe Lua state" - ) - } - - /// Creates a new Lua state and loads all the standard libraries. - /// - /// # Safety - /// The created Lua state would not have safety guarantees and would allow to load C modules. - pub unsafe fn unsafe_new() -> Lua { - Self::unsafe_new_with(StdLib::ALL, LuaOptions::default()) - } - - /// Creates a new Lua state and loads the specified safe subset of the standard libraries. - /// - /// Use the [`StdLib`] flags to specify the libraries you want to load. - /// - /// # Safety - /// The created Lua state would have _some_ safety guarantees and would not allow to load unsafe - /// standard libraries or C modules. - /// - /// See [`StdLib`] documentation for a list of unsafe modules that cannot be loaded. - /// - /// [`StdLib`]: crate::StdLib - pub fn new_with(libs: StdLib, options: LuaOptions) -> Result { - #[cfg(not(feature = "luau"))] - if libs.contains(StdLib::DEBUG) { - return Err(Error::SafetyError( - "the unsafe `debug` module can't be loaded using safe `new_with`".to_string(), - )); - } - #[cfg(feature = "luajit")] - { - if libs.contains(StdLib::FFI) { - return Err(Error::SafetyError( - "the unsafe `ffi` module can't be loaded using safe `new_with`".to_string(), - )); - } - } - - let mut lua = unsafe { Self::inner_new(libs, options) }; - - #[cfg(not(feature = "luau"))] - if libs.contains(StdLib::PACKAGE) { - mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules"); - } - lua.safe = true; - - Ok(lua) - } - - /// Creates a new Lua state and loads the specified subset of the standard libraries. - /// - /// Use the [`StdLib`] flags to specify the libraries you want to load. - /// - /// # Safety - /// The created Lua state will not have safety guarantees and allow to load C modules. - /// - /// [`StdLib`]: crate::StdLib - pub unsafe fn unsafe_new_with(libs: StdLib, options: LuaOptions) -> Lua { - #[cfg(not(feature = "luau"))] - ffi::keep_lua_symbols(); - Self::inner_new(libs, options) - } - - unsafe fn inner_new(libs: StdLib, options: LuaOptions) -> Lua { - #[cfg_attr( - any(feature = "lua51", feature = "luajit", feature = "luau"), - allow(dead_code) - )] - unsafe extern "C" fn allocator( - extra_data: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, - ) -> *mut c_void { - use std::alloc; - - let mem_info = &mut *(extra_data as *mut MemoryInfo); - - if nsize == 0 { - // Free memory - if !ptr.is_null() { - let layout = - alloc::Layout::from_size_align_unchecked(osize, ffi::SYS_MIN_ALIGN); - alloc::dealloc(ptr as *mut u8, layout); - mem_info.used_memory -= osize as isize; - } - return ptr::null_mut(); - } - - // Are we fit to the memory limits? - let mut mem_diff = nsize as isize; - if !ptr.is_null() { - mem_diff -= osize as isize; - } - let new_used_memory = mem_info.used_memory + mem_diff; - if mem_info.memory_limit > 0 && new_used_memory > mem_info.memory_limit { - return ptr::null_mut(); - } - - let new_layout = alloc::Layout::from_size_align_unchecked(nsize, ffi::SYS_MIN_ALIGN); - - if ptr.is_null() { - // Allocate new memory - let new_ptr = alloc::alloc(new_layout) as *mut c_void; - if !new_ptr.is_null() { - mem_info.used_memory += mem_diff; - } - return new_ptr; - } - - // Reallocate memory - let old_layout = alloc::Layout::from_size_align_unchecked(osize, ffi::SYS_MIN_ALIGN); - let new_ptr = alloc::realloc(ptr as *mut u8, old_layout, nsize) as *mut c_void; - - if !new_ptr.is_null() { - mem_info.used_memory += mem_diff; - } else if !ptr.is_null() && nsize < osize { - // Should not happen - alloc::handle_alloc_error(new_layout); - } - - new_ptr - } - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - let mem_info = Box::into_raw(Box::new(MemoryInfo { - used_memory: 0, - memory_limit: 0, - })); - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - let state = ffi::lua_newstate(allocator, mem_info as *mut c_void); - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - let state = ffi::luaL_newstate(); - - ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1); - ffi::lua_pop(state, 1); - - let lua = Lua::init_from_ptr(state); - let extra = &mut *lua.extra.get(); - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - { - extra.mem_info = ptr::NonNull::new(mem_info); - } - - mlua_expect!( - load_from_std_lib(state, libs), - "Error during loading standard libraries" - ); - extra.libs |= libs; - - if !options.catch_rust_panics { - mlua_expect!( - (|| -> Result<()> { - let _sg = StackGuard::new(lua.state); - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - ffi::lua_rawgeti(lua.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_pushvalue(lua.state, ffi::LUA_GLOBALSINDEX); - - ffi::lua_pushcfunction(lua.state, safe_pcall); - rawset_field(lua.state, -2, "pcall")?; - - ffi::lua_pushcfunction(lua.state, safe_xpcall); - rawset_field(lua.state, -2, "xpcall")?; - - Ok(()) - })(), - "Error during applying option `catch_rust_panics`" - ) - } - - #[cfg(feature = "async")] - if options.thread_cache_size > 0 { - extra.recycled_thread_cache = Vec::with_capacity(options.thread_cache_size); - } - - #[cfg(feature = "luau")] - mlua_expect!(lua.prepare_luau_state(), "Error preparing Luau state"); - - lua - } - - /// Constructs a new Lua instance from an existing raw state. - /// - /// Once called, a returned Lua state is cached in the registry and can be retrieved - /// by calling this function again. - #[allow(clippy::missing_safety_doc)] - pub unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Lua { - let main_state = get_main_state(state).unwrap_or(state); - let main_state_top = ffi::lua_gettop(main_state); - - if let Some(lua) = Lua::make_from_ptr(state) { - return lua; - } - - mlua_expect!( - (|state| { - init_error_registry(state)?; - - // Create the internal metatables and place them in the registry - // to prevent them from being garbage collected. - - init_gc_metatable::>>(state, None)?; - init_gc_metatable::(state, None)?; - init_gc_metatable::(state, None)?; - #[cfg(feature = "async")] - { - init_gc_metatable::(state, None)?; - init_gc_metatable::(state, None)?; - init_gc_metatable::(state, None)?; - init_gc_metatable::>(state, None)?; - } - - // Init serde metatables - #[cfg(feature = "serialize")] - crate::serde::init_metatables(state)?; - - Ok::<_, Error>(()) - })(main_state), - "Error during Lua construction", - ); - - // Create ref stack thread and place it in the registry to prevent it from being garbage - // collected. - let ref_thread = mlua_expect!( - protect_lua!(state, 0, 0, |state| { - let thread = ffi::lua_newthread(state); - ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); - thread - }), - "Error while creating ref thread", - ); - - // Create empty Waker slot on the ref thread - #[cfg(feature = "async")] - let ref_waker_idx = { - mlua_expect!( - push_gc_userdata::>(ref_thread, None), - "Error while creating Waker slot" - ); - ffi::lua_gettop(ref_thread) - }; - let ref_stack_top = ffi::lua_gettop(ref_thread); - - // Create ExtraData - - let extra = Arc::new(UnsafeCell::new(ExtraData { - inner: None, - registered_userdata: FxHashMap::default(), - registered_userdata_mt: FxHashMap::default(), - registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))), - app_data: RefCell::new(HashMap::new()), - ref_thread, - libs: StdLib::NONE, - mem_info: None, - // We need 1 extra stack space to move values in and out of the ref stack. - ref_stack_size: ffi::LUA_MINSTACK - 1, - ref_stack_top, - ref_free: Vec::new(), - wrapped_failures_cache: Vec::with_capacity(WRAPPED_FAILURES_CACHE_SIZE), - multivalue_cache: Vec::with_capacity(MULTIVALUE_CACHE_SIZE), - #[cfg(feature = "async")] - recycled_thread_cache: Vec::new(), - #[cfg(feature = "async")] - ref_waker_idx, - #[cfg(not(feature = "luau"))] - hook_callback: None, - #[cfg(feature = "lua54")] - warn_callback: None, - #[cfg(feature = "luau")] - interrupt_callback: None, - #[cfg(feature = "luau")] - sandboxed: false, - })); - - mlua_expect!( - (|state| { - push_gc_userdata(state, Arc::clone(&extra))?; - protect_lua!(state, 1, 0, fn(state) { - let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key); - }) - })(main_state), - "Error while storing extra data", - ); - - // Register `DestructedUserdataMT` type - get_destructed_userdata_metatable(main_state); - let destructed_mt_ptr = ffi::lua_topointer(main_state, -1); - let destructed_mt_typeid = Some(TypeId::of::()); - (*extra.get()) - .registered_userdata_mt - .insert(destructed_mt_ptr, destructed_mt_typeid); - ffi::lua_pop(main_state, 1); - - mlua_debug_assert!( - ffi::lua_gettop(main_state) == main_state_top, - "stack leak during creation" - ); - assert_stack(main_state, ffi::LUA_MINSTACK); - - // Set Luau callbacks userdata to extra data - // We can use global callbacks userdata since we don't allow C modules in Luau - #[cfg(feature = "luau")] - { - let extra_raw = Box::into_raw(Box::new(Arc::clone(&extra))); - (*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void; - } - - let inner = Arc::new(UnsafeCell::new(LuaInner { - state, - main_state, - extra: Arc::clone(&extra), - safe: false, - #[cfg(feature = "luau")] - compiler: None, - _no_ref_unwind_safe: PhantomData, - })); - - (*extra.get()).inner = Some(ManuallyDrop::new(Arc::clone(&inner))); - #[cfg(not(feature = "module"))] - Arc::decrement_strong_count(Arc::as_ptr(&inner)); - - Lua(inner) - } - - /// Loads the specified subset of the standard libraries into an existing Lua state. - /// - /// Use the [`StdLib`] flags to specify the libraries you want to load. - /// - /// [`StdLib`]: crate::StdLib - pub fn load_from_std_lib(&self, libs: StdLib) -> Result<()> { - #[cfg(not(feature = "luau"))] - if self.safe && libs.contains(StdLib::DEBUG) { - return Err(Error::SafetyError( - "the unsafe `debug` module can't be loaded in safe mode".to_string(), - )); - } - #[cfg(feature = "luajit")] - { - if self.safe && libs.contains(StdLib::FFI) { - return Err(Error::SafetyError( - "the unsafe `ffi` module can't be loaded in safe mode".to_string(), - )); - } - } - - let res = unsafe { load_from_std_lib(self.main_state, libs) }; - - // If `package` library loaded into a safe lua state then disable C modules - let extra = unsafe { &mut *self.extra.get() }; - #[cfg(not(feature = "luau"))] - { - let curr_libs = extra.libs; - if self.safe && (curr_libs ^ (curr_libs | libs)).contains(StdLib::PACKAGE) { - mlua_expect!(self.disable_c_modules(), "Error during disabling C modules"); - } - } - extra.libs |= libs; - - res - } - - /// Loads module `modname` into an existing Lua state using the specified entrypoint - /// function. - /// - /// Internally calls the Lua function `func` with the string `modname` as an argument, - /// sets the call result to `package.loaded[modname]` and returns copy of the result. - /// - /// If `package.loaded[modname]` value is not nil, returns copy of the value without - /// calling the function. - /// - /// If the function does not return a non-nil value then this method assigns true to - /// `package.loaded[modname]`. - /// - /// Behavior is similar to Lua's [`require`] function. - /// - /// [`require`]: https://www.lua.org/manual/5.4/manual.html#pdf-require - pub fn load_from_function<'lua, S, T>( - &'lua self, - modname: &S, - func: Function<'lua>, - ) -> Result - where - S: AsRef<[u8]> + ?Sized, - T: FromLua<'lua>, - { - let loaded = unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - protect_lua!(self.state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(self.pop_ref()) - }; - - let modname = self.create_string(modname)?; - let value = match loaded.raw_get(modname.clone())? { - Value::Nil => { - let result = match func.call(modname.clone())? { - Value::Nil => Value::Boolean(true), - res => res, - }; - loaded.raw_set(modname, result.clone())?; - result - } - res => res, - }; - T::from_lua(value, self) - } - - /// Unloads module `modname`. - /// - /// Removes module from the [`package.loaded`] table which allows to load it again. - /// It does not support unloading binary Lua modules since they are internally cached and can be - /// unloaded only by closing Lua state. - /// - /// [`package.loaded`]: https://www.lua.org/manual/5.4/manual.html#pdf-package.loaded - pub fn unload(&self, modname: &S) -> Result<()> - where - S: AsRef<[u8]> + ?Sized, - { - let loaded = unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - protect_lua!(self.state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(self.pop_ref()) - }; - - let modname = self.create_string(modname)?; - loaded.raw_remove(modname)?; - Ok(()) - } - - /// Consumes and leaks `Lua` object, returning a static reference `&'static Lua`. - /// - /// This function is useful when the `Lua` object is supposed to live for the remainder - /// of the program's life. - /// In particular in asynchronous context this will allow to spawn Lua tasks to execute - /// in background. - /// - /// Dropping the returned reference will cause a memory leak. If this is not acceptable, - /// the reference should first be wrapped with the [`Lua::from_static`] function producing a `Lua`. - /// This `Lua` object can then be dropped which will properly release the allocated memory. - /// - /// [`Lua::from_static`]: #method.from_static - #[doc(hidden)] - pub fn into_static(self) -> &'static Self { - Box::leak(Box::new(self)) - } - - /// Constructs a `Lua` from a static reference to it. - /// - /// # Safety - /// This function is unsafe because improper use may lead to memory problems or undefined behavior. - #[doc(hidden)] - pub unsafe fn from_static(lua: &'static Lua) -> Self { - *Box::from_raw(lua as *const Lua as *mut Lua) - } - - // Executes module entrypoint function, which returns only one Value. - // The returned value then pushed onto the stack. - #[doc(hidden)] - #[cfg(not(tarpaulin_include))] - pub unsafe fn entrypoint<'lua, A, R, F>(self, func: F) -> Result - where - A: FromLuaMulti<'lua>, - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - let entrypoint_inner = |lua: &'lua Lua, func: F| { - let nargs = ffi::lua_gettop(lua.state); - check_stack(lua.state, 3)?; - - let mut args = MultiValue::new(); - args.reserve(nargs as usize); - for _ in 0..nargs { - args.push_front(lua.pop_value()); - } - - // We create callback rather than call `func` directly to catch errors - // with attached stacktrace. - let callback = lua.create_callback(Box::new(move |lua, args| { - func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - }))?; - callback.call(args) - }; - - match entrypoint_inner(mem::transmute(&self), func) { - Ok(res) => { - self.push_value(res)?; - Ok(1) - } - Err(err) => { - self.push_value(Value::Error(err))?; - let state = self.state; - // Lua (self) must be dropped before triggering longjmp - drop(self); - ffi::lua_error(state) - } - } - } - - // A simple module entrypoint without arguments - #[doc(hidden)] - #[cfg(not(tarpaulin_include))] - pub unsafe fn entrypoint1<'lua, R, F>(self, func: F) -> Result - where - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, - { - self.entrypoint(move |lua, _: ()| func(lua)) - } - - /// Enables (or disables) sandbox mode on this Lua instance. - /// - /// This method, in particular: - /// - Set all libraries to read-only - /// - Set all builtin metatables to read-only - /// - Set globals to read-only (and activates safeenv) - /// - Setup local environment table that performs writes locally and proxies reads - /// to the global environment. - /// - /// # Examples - /// - /// ``` - /// # use mlua::{Lua, Result}; - /// # fn main() -> Result<()> { - /// let lua = Lua::new(); - /// - /// lua.sandbox(true)?; - /// lua.load("var = 123").exec()?; - /// assert_eq!(lua.globals().get::<_, u32>("var")?, 123); - /// - /// // Restore the global environment (clear changes made in sandbox) - /// lua.sandbox(false)?; - /// assert_eq!(lua.globals().get::<_, Option>("var")?, None); - /// # Ok(()) - /// # } - /// ``` - /// - /// Requires `feature = "luau"` - #[cfg(any(feature = "luau", docsrs))] - #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - pub fn sandbox(&self, enabled: bool) -> Result<()> { - unsafe { - let extra = &mut *self.extra.get(); - if extra.sandboxed != enabled { - let state = self.main_state; - check_stack(state, 3)?; - protect_lua!(state, 0, 0, |state| { - if enabled { - ffi::luaL_sandbox(state, 1); - ffi::luaL_sandboxthread(state); - } else { - // Restore original `LUA_GLOBALSINDEX` - self.ref_thread_exec(|ref_thread| { - ffi::lua_xpush(ref_thread, state, ffi::LUA_GLOBALSINDEX); - ffi::lua_replace(state, ffi::LUA_GLOBALSINDEX); - }); - ffi::luaL_sandbox(state, 0); - } - })?; - extra.sandboxed = enabled; - } - Ok(()) - } - } - - /// Sets a 'hook' function that will periodically be called as Lua code executes. - /// - /// When exactly the hook function is called depends on the contents of the `triggers` - /// parameter, see [`HookTriggers`] for more details. - /// - /// The provided hook function can error, and this error will be propagated through the Lua code - /// that was executing at the time the hook was triggered. This can be used to implement a - /// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and - /// erroring once an instruction limit has been reached. - /// - /// # Example - /// - /// Shows each line number of code being executed by the Lua interpreter. - /// - /// ``` - /// # use mlua::{Lua, HookTriggers, Result}; - /// # fn main() -> Result<()> { - /// let lua = Lua::new(); - /// lua.set_hook(HookTriggers::every_line(), |_lua, debug| { - /// println!("line {}", debug.curr_line()); - /// Ok(()) - /// })?; - /// - /// lua.load(r#" - /// local x = 2 + 3 - /// local y = x * 63 - /// local z = string.len(x..", "..y) - /// "#).exec() - /// # } - /// ``` - /// - /// [`HookTriggers`]: crate::HookTriggers - /// [`HookTriggers.every_nth_instruction`]: crate::HookTriggers::every_nth_instruction - #[cfg(not(feature = "luau"))] - #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] - pub fn set_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> - where - F: 'static + MaybeSend + Fn(&Lua, Debug) -> Result<()>, - { - unsafe extern "C" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { - let lua = match Lua::make_from_ptr(state) { - Some(lua) => lua, - None => return, - }; - let extra = lua.extra.get(); - callback_error_ext(state, extra, move |_| { - let debug = Debug::new(&lua, ar); - let hook_cb = (*lua.extra.get()).hook_callback.clone(); - let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc"); - if Arc::strong_count(&hook_cb) > 2 { - return Ok(()); // Don't allow recursion - } - hook_cb(&lua, debug) - }) - } - - unsafe { - let state = get_main_state(self.main_state).ok_or(Error::MainThreadNotAvailable)?; - (*self.extra.get()).hook_callback = Some(Arc::new(callback)); - ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count()); - } - Ok(()) - } - - /// Removes any hook previously set by `set_hook`. - /// - /// This function has no effect if a hook was not previously set. - #[cfg(not(feature = "luau"))] - #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] - pub fn remove_hook(&self) { - unsafe { - // If main_state is not available, then sethook wasn't called. - let state = match get_main_state(self.main_state) { - Some(state) => state, - None => return, - }; - (*self.extra.get()).hook_callback = None; - ffi::lua_sethook(state, None, 0, 0); - } - } - - /// Sets an 'interrupt' function that will periodically be called by Luau VM. - /// - /// Any Luau code is guaranteed to call this handler "eventually" - /// (in practice this can happen at any function call or at any loop iteration). - /// - /// The provided interrupt function can error, and this error will be propagated through - /// the Luau code that was executing at the time the interrupt was triggered. - /// Also this can be used to implement continuous execution limits by instructing Luau VM to yield - /// by returning [`VmState::Yield`]. - /// - /// This is similar to [`Lua::set_hook`] but in more simplified form. - /// - /// # Example - /// - /// Periodically yield Luau VM to suspend execution. - /// - /// ``` - /// # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; - /// # use mlua::{Lua, Result, ThreadStatus, VmState}; - /// # fn main() -> Result<()> { - /// let lua = Lua::new(); - /// let count = Arc::new(AtomicU64::new(0)); - /// lua.set_interrupt(move || { - /// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 { - /// return Ok(VmState::Yield); - /// } - /// Ok(VmState::Continue) - /// }); - /// - /// let co = lua.create_thread( - /// lua.load(r#" - /// local b = 0 - /// for _, x in ipairs({1, 2, 3}) do b += x end - /// "#) - /// .into_function()?, - /// )?; - /// while co.status() == ThreadStatus::Resumable { - /// co.resume(())?; - /// } - /// # Ok(()) - /// # } - /// ``` - #[cfg(any(feature = "luau", docsrs))] - #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - pub fn set_interrupt(&self, callback: F) - where - F: 'static + MaybeSend + Fn() -> Result, - { - unsafe extern "C" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) { - if gc >= 0 { - // We don't support GC interrupts since they cannot survive Lua exceptions - return; - } - let extra = match extra_data(state) { - Some(e) => e.get(), - None => return, - }; - let result = callback_error_ext(state, extra, move |_| { - let interrupt_cb = (*extra).interrupt_callback.clone(); - let interrupt_cb = - mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc"); - if Arc::strong_count(&interrupt_cb) > 2 { - return Ok(VmState::Continue); // Don't allow recursion - } - interrupt_cb() - }); - match result { - VmState::Continue => {} - VmState::Yield => { - ffi::lua_yield(state, 0); - } - } - } - - unsafe { - (*self.extra.get()).interrupt_callback = Some(Arc::new(callback)); - (*ffi::lua_callbacks(self.main_state)).interrupt = Some(interrupt_proc); - } - } - - /// Removes any 'interrupt' previously set by `set_interrupt`. - /// - /// This function has no effect if an 'interrupt' was not previously set. - #[cfg(any(feature = "luau", docsrs))] - #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - pub fn remove_interrupt(&self) { - unsafe { - (*self.extra.get()).interrupt_callback = None; - (*ffi::lua_callbacks(self.main_state)).interrupt = None; - } - } - - /// Sets the warning function to be used by Lua to emit warnings. - /// - /// Requires `feature = "lua54"` - #[cfg(feature = "lua54")] - #[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] - pub fn set_warning_function(&self, callback: F) - where - F: 'static + MaybeSend + Fn(&Lua, &CStr, bool) -> Result<()>, - { - unsafe extern "C" fn warn_proc(ud: *mut c_void, msg: *const c_char, tocont: c_int) { - let state = ud as *mut ffi::lua_State; - let lua = match Lua::make_from_ptr(state) { - Some(lua) => lua, - None => return, - }; - let extra = lua.extra.get(); - callback_error_ext(state, extra, move |_| { - let cb = mlua_expect!( - (*lua.extra.get()).warn_callback.as_ref(), - "no warning callback set in warn_proc" - ); - let msg = CStr::from_ptr(msg); - cb(&lua, msg, tocont != 0) - }); - } - - let state = self.main_state; - unsafe { - (*self.extra.get()).warn_callback = Some(Box::new(callback)); - ffi::lua_setwarnf(state, Some(warn_proc), state as *mut c_void); - } - } - - /// Removes warning function previously set by `set_warning_function`. - /// - /// This function has no effect if a warning function was not previously set. - /// - /// Requires `feature = "lua54"` - #[cfg(feature = "lua54")] - #[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] - pub fn remove_warning_function(&self) { - unsafe { - (*self.extra.get()).warn_callback = None; - ffi::lua_setwarnf(self.main_state, None, ptr::null_mut()); - } - } - - /// Emits a warning with the given message. - /// - /// A message in a call with `tocont` set to `true` should be continued in another call to this function. - /// - /// Requires `feature = "lua54"` - #[cfg(feature = "lua54")] - #[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] - pub fn warning>>(&self, msg: S, tocont: bool) -> Result<()> { - let msg = CString::new(msg).map_err(|err| Error::RuntimeError(err.to_string()))?; - unsafe { ffi::lua_warning(self.state, msg.as_ptr(), if tocont { 1 } else { 0 }) }; - Ok(()) - } - - /// Gets information about the interpreter runtime stack. - /// - /// This function returns [`Debug`] structure that can be used to get information about the function - /// executing at a given level. Level `0` is the current running function, whereas level `n+1` is the - /// function that has called level `n` (except for tail calls, which do not count in the stack). - /// - /// [`Debug`]: crate::hook::Debug - pub fn inspect_stack(&self, level: usize) -> Option { - unsafe { - let mut ar: ffi::lua_Debug = mem::zeroed(); - let level = level as c_int; - #[cfg(not(feature = "luau"))] - if ffi::lua_getstack(self.state, level, &mut ar) == 0 { - return None; - } - #[cfg(feature = "luau")] - if ffi::lua_getinfo(self.state, level, cstr!(""), &mut ar) == 0 { - return None; - } - Some(Debug::new_owned(self, level, ar)) - } - } - - /// Returns the amount of memory (in bytes) currently used inside this Lua state. - pub fn used_memory(&self) -> usize { - unsafe { - match (*self.extra.get()).mem_info.map(|x| x.as_ref()) { - Some(mem_info) => mem_info.used_memory as usize, - None => { - // Get data from the Lua GC - let used_kbytes = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNT, 0); - let used_kbytes_rem = ffi::lua_gc(self.main_state, ffi::LUA_GCCOUNTB, 0); - (used_kbytes as usize) * 1024 + (used_kbytes_rem as usize) - } - } - } - } - - /// Sets a memory limit (in bytes) on this Lua state. - /// - /// Once an allocation occurs that would pass this memory limit, - /// a `Error::MemoryError` is generated instead. - /// Returns previous limit (zero means no limit). - /// - /// Does not work on module mode where Lua state is managed externally. - /// - /// Requires `feature = "lua54/lua53/lua52"` - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - pub fn set_memory_limit(&self, memory_limit: usize) -> Result { - unsafe { - match (*self.extra.get()).mem_info.map(|mut x| x.as_mut()) { - Some(mem_info) => { - let prev_limit = mem_info.memory_limit as usize; - mem_info.memory_limit = memory_limit as isize; - Ok(prev_limit) - } - None => Err(Error::MemoryLimitNotAvailable), - } - } - } - - /// Returns true if the garbage collector is currently running automatically. - /// - /// Requires `feature = "lua54/lua53/lua52/luau"` - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luau" - ))] - pub fn gc_is_running(&self) -> bool { - unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCISRUNNING, 0) != 0 } - } - - /// Stop the Lua GC from running - pub fn gc_stop(&self) { - unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSTOP, 0) }; - } - - /// Restarts the Lua GC if it is not running - pub fn gc_restart(&self) { - unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCRESTART, 0) }; - } - - /// Perform a full garbage-collection cycle. - /// - /// It may be necessary to call this function twice to collect all currently unreachable - /// objects. Once to finish the current gc cycle, and once to start and finish the next cycle. - pub fn gc_collect(&self) -> Result<()> { - unsafe { - check_stack(self.main_state, 2)?; - protect_lua!(self.main_state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0)) - } - } - - /// Steps the garbage collector one indivisible step. - /// - /// Returns true if this has finished a collection cycle. - pub fn gc_step(&self) -> Result { - self.gc_step_kbytes(0) - } - - /// Steps the garbage collector as though memory had been allocated. - /// - /// if `kbytes` is 0, then this is the same as calling `gc_step`. Returns true if this step has - /// finished a collection cycle. - pub fn gc_step_kbytes(&self, kbytes: c_int) -> Result { - unsafe { - check_stack(self.main_state, 3)?; - protect_lua!(self.main_state, 0, 0, |state| { - ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0 - }) - } - } - - /// Sets the 'pause' value of the collector. - /// - /// Returns the previous value of 'pause'. More information can be found in the Lua - /// [documentation]. - /// - /// For Luau this parameter sets GC goal - /// - /// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5 - pub fn gc_set_pause(&self, pause: c_int) -> c_int { - unsafe { - #[cfg(not(feature = "luau"))] - return ffi::lua_gc(self.main_state, ffi::LUA_GCSETPAUSE, pause); - #[cfg(feature = "luau")] - return ffi::lua_gc(self.main_state, ffi::LUA_GCSETGOAL, pause); - } - } - - /// Sets the 'step multiplier' value of the collector. - /// - /// Returns the previous value of the 'step multiplier'. More information can be found in the - /// Lua [documentation]. - /// - /// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5 - pub fn gc_set_step_multiplier(&self, step_multiplier: c_int) -> c_int { - unsafe { ffi::lua_gc(self.main_state, ffi::LUA_GCSETSTEPMUL, step_multiplier) } - } - - /// Changes the collector to incremental mode with the given parameters. - /// - /// Returns the previous mode (always `GCMode::Incremental` in Lua < 5.4). - /// More information can be found in the Lua [documentation]. - /// - /// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5.1 - pub fn gc_inc(&self, pause: c_int, step_multiplier: c_int, step_size: c_int) -> GCMode { - let state = self.main_state; - - #[cfg(any( - feature = "lua53", - feature = "lua52", - feature = "lua51", - feature = "luajit", - feature = "luau" - ))] - unsafe { - if pause > 0 { - #[cfg(not(feature = "luau"))] - ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, pause); - #[cfg(feature = "luau")] - ffi::lua_gc(state, ffi::LUA_GCSETGOAL, pause); - } - - if step_multiplier > 0 { - ffi::lua_gc(state, ffi::LUA_GCSETSTEPMUL, step_multiplier); - } - - #[cfg(feature = "luau")] - if step_size > 0 { - ffi::lua_gc(state, ffi::LUA_GCSETSTEPSIZE, step_size); - } - #[cfg(not(feature = "luau"))] - let _ = step_size; // Ignored - - GCMode::Incremental - } - - #[cfg(feature = "lua54")] - let prev_mode = - unsafe { ffi::lua_gc(state, ffi::LUA_GCINC, pause, step_multiplier, step_size) }; - #[cfg(feature = "lua54")] - match prev_mode { - ffi::LUA_GCINC => GCMode::Incremental, - ffi::LUA_GCGEN => GCMode::Generational, - _ => unreachable!(), - } - } - - /// Changes the collector to generational mode with the given parameters. - /// - /// Returns the previous mode. More information about the generational GC - /// can be found in the Lua 5.4 [documentation][lua_doc]. - /// - /// Requires `feature = "lua54"` - /// - /// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#2.5.2 - #[cfg(any(feature = "lua54"))] - #[cfg_attr(docsrs, doc(cfg(feature = "lua54")))] - pub fn gc_gen(&self, minor_multiplier: c_int, major_multiplier: c_int) -> GCMode { - let state = self.main_state; - let prev_mode = - unsafe { ffi::lua_gc(state, ffi::LUA_GCGEN, minor_multiplier, major_multiplier) }; - match prev_mode { - ffi::LUA_GCGEN => GCMode::Generational, - ffi::LUA_GCINC => GCMode::Incremental, - _ => unreachable!(), - } - } - - /// Sets a default Luau compiler (with custom options). - /// - /// This compiler will be used by default to load all Lua chunks - /// including via `require` function. - /// - /// See [`Compiler`] for details and possible options. - /// - /// Requires `feature = "luau"` - #[cfg(any(feature = "luau", doc))] - #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - pub fn set_compiler(&self, compiler: Compiler) { - unsafe { (*self.0.get()).compiler = Some(compiler) }; - } - - /// Returns Lua source code as a `Chunk` builder type. - /// - /// In order to actually compile or run the resulting code, you must call [`Chunk::exec`] or - /// similar on the returned builder. Code is not even parsed until one of these methods is - /// called. - /// - /// [`Chunk::exec`]: crate::Chunk::exec - #[track_caller] - pub fn load<'lua, 'a, S>(&'lua self, chunk: &'a S) -> Chunk<'lua, 'a> - where - S: AsChunk<'lua> + ?Sized, - { - let name = chunk - .name() - .unwrap_or_else(|| Location::caller().to_string()); - - Chunk { - lua: self, - source: chunk.source(), - name: Some(name), - env: chunk.env(self), - mode: chunk.mode(), - #[cfg(feature = "luau")] - compiler: self.compiler.clone(), - } - } - - pub(crate) fn load_chunk<'lua>( - &'lua self, - source: &[u8], - name: Option<&CStr>, - env: Option>, - mode: Option, - ) -> Result> { - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 1)?; - - let mode_str = match mode { - Some(ChunkMode::Binary) => cstr!("b"), - Some(ChunkMode::Text) => cstr!("t"), - None => cstr!("bt"), - }; - - match ffi::luaL_loadbufferx( - self.state, - source.as_ptr() as *const c_char, - source.len(), - name.map(|n| n.as_ptr()).unwrap_or_else(ptr::null), - mode_str, - ) { - ffi::LUA_OK => { - if let Some(env) = env { - self.push_value(env)?; - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - ffi::lua_setupvalue(self.state, -2, 1); - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_setfenv(self.state, -2); - } - Ok(Function(self.pop_ref())) - } - err => Err(pop_error(self.state, err)), - } - } - } - - /// Create and return an interned Lua string. Lua strings can be arbitrary [u8] data including - /// embedded nulls, so in addition to `&str` and `&String`, you can also pass plain `&[u8]` - /// here. - pub fn create_string(&self, s: &S) -> Result - where - S: AsRef<[u8]> + ?Sized, - { - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 3)?; - push_string(self.state, s)?; - Ok(String(self.pop_ref())) - } - } - - /// Creates and returns a new empty table. - pub fn create_table(&self) -> Result
{ - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - protect_lua!(self.state, 0, 1, fn(state) ffi::lua_newtable(state))?; - Ok(Table(self.pop_ref())) - } - } - - /// Creates and returns a new empty table, with the specified capacity. - /// `narr` is a hint for how many elements the table will have as a sequence; - /// `nrec` is a hint for how many other elements the table will have. - /// Lua may use these hints to preallocate memory for the new table. - pub fn create_table_with_capacity(&self, narr: c_int, nrec: c_int) -> Result
{ - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 3)?; - push_table(self.state, narr, nrec)?; - Ok(Table(self.pop_ref())) - } - } - - /// Creates a table and fills it with values from an iterator. - pub fn create_table_from<'lua, K, V, I>(&'lua self, iter: I) -> Result> - where - K: ToLua<'lua>, - V: ToLua<'lua>, - I: IntoIterator, - { - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 6)?; - - let iter = iter.into_iter(); - let lower_bound = iter.size_hint().0; - push_table(self.state, 0, lower_bound as c_int)?; - for (k, v) in iter { - self.push_value(k.to_lua(self)?)?; - self.push_value(v.to_lua(self)?)?; - protect_lua!(self.state, 3, 1, fn(state) ffi::lua_rawset(state, -3))?; - } - - Ok(Table(self.pop_ref())) - } - } - - /// Creates a table from an iterator of values, using `1..` as the keys. - pub fn create_sequence_from<'lua, T, I>(&'lua self, iter: I) -> Result> - where - T: ToLua<'lua>, - I: IntoIterator, - { - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 5)?; - - let iter = iter.into_iter(); - let lower_bound = iter.size_hint().0; - push_table(self.state, lower_bound as c_int, 0)?; - for (i, v) in iter.enumerate() { - self.push_value(v.to_lua(self)?)?; - protect_lua!(self.state, 2, 1, |state| { - ffi::lua_rawseti(state, -2, (i + 1) as Integer); - })?; - } - - Ok(Table(self.pop_ref())) - } - } - - /// Wraps a Rust function or closure, creating a callable Lua function handle to it. - /// - /// The function's return value is always a `Result`: If the function returns `Err`, the error - /// is raised as a Lua error, which can be caught using `(x)pcall` or bubble up to the Rust code - /// that invoked the Lua code. This allows using the `?` operator to propagate errors through - /// intermediate Lua code. - /// - /// If the function returns `Ok`, the contained value will be converted to one or more Lua - /// values. For details on Rust-to-Lua conversions, refer to the [`ToLua`] and [`ToLuaMulti`] - /// traits. - /// - /// # Examples - /// - /// Create a function which prints its argument: - /// - /// ``` - /// # use mlua::{Lua, Result}; - /// # fn main() -> Result<()> { - /// # let lua = Lua::new(); - /// let greet = lua.create_function(|_, name: String| { - /// println!("Hello, {}!", name); - /// Ok(()) - /// }); - /// # let _ = greet; // used - /// # Ok(()) - /// # } - /// ``` - /// - /// Use tuples to accept multiple arguments: - /// - /// ``` - /// # use mlua::{Lua, Result}; - /// # fn main() -> Result<()> { - /// # let lua = Lua::new(); - /// let print_person = lua.create_function(|_, (name, age): (String, u8)| { - /// println!("{} is {} years old!", name, age); - /// Ok(()) - /// }); - /// # let _ = print_person; // used - /// # Ok(()) - /// # } - /// ``` - /// - /// [`ToLua`]: crate::ToLua - /// [`ToLuaMulti`]: crate::ToLuaMulti - pub fn create_function<'lua, A, R, F>(&'lua self, func: F) -> Result> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - self.create_callback(Box::new(move |lua, args| { - func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })) - } - - /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. - /// - /// This is a version of [`create_function`] that accepts a FnMut argument. Refer to - /// [`create_function`] for more information about the implementation. - /// - /// [`create_function`]: #method.create_function - pub fn create_function_mut<'lua, A, R, F>(&'lua self, func: F) -> Result> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - let func = RefCell::new(func); - self.create_function(move |lua, args| { - (*func - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?)(lua, args) - }) - } - - /// Wraps a C function, creating a callable Lua function handle to it. - /// - /// # Safety - /// This function is unsafe because provides a way to execute unsafe C function. - pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { - check_stack(self.state, 1)?; - ffi::lua_pushcfunction(self.state, func); - Ok(Function(self.pop_ref())) - } - - /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. - /// - /// While executing the function Rust will poll Future and if the result is not ready, call - /// `yield()` passing internal representation of a `Poll::Pending` value. - /// - /// The function must be called inside Lua coroutine ([`Thread`]) to be able to suspend its execution. - /// An executor should be used to poll [`AsyncThread`] and mlua will take a provided Waker - /// in that case. Otherwise noop waker will be used if try to call the function outside of Rust - /// executors. - /// - /// The family of `call_async()` functions takes care about creating [`Thread`]. - /// - /// Requires `feature = "async"` - /// - /// # Examples - /// - /// Non blocking sleep: - /// - /// ``` - /// use std::time::Duration; - /// use futures_timer::Delay; - /// use mlua::{Lua, Result}; - /// - /// async fn sleep(_lua: &Lua, n: u64) -> Result<&'static str> { - /// Delay::new(Duration::from_millis(n)).await; - /// Ok("done") - /// } - /// - /// #[tokio::main] - /// async fn main() -> Result<()> { - /// let lua = Lua::new(); - /// lua.globals().set("sleep", lua.create_async_function(sleep)?)?; - /// let res: String = lua.load("return sleep(...)").call_async(100).await?; // Sleep 100ms - /// assert_eq!(res, "done"); - /// Ok(()) - /// } - /// ``` - /// - /// [`Thread`]: crate::Thread - /// [`AsyncThread`]: crate::AsyncThread - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn create_async_function<'lua, A, R, F, FR>(&'lua self, func: F) -> Result> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - self.create_async_callback(Box::new(move |lua, args| { - let args = match A::from_lua_multi(args, lua) { - Ok(args) => args, - Err(e) => return Box::pin(future::err(e)), - }; - Box::pin(func(lua, args).and_then(move |ret| future::ready(ret.to_lua_multi(lua)))) - })) - } - - /// Wraps a Lua function into a new thread (or coroutine). - /// - /// Equivalent to `coroutine.create`. - pub fn create_thread<'lua>(&'lua self, func: Function<'lua>) -> Result> { - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 3)?; - - let thread_state = protect_lua!(self.state, 0, 1, |state| ffi::lua_newthread(state))?; - self.push_ref(&func.0); - ffi::lua_xmove(self.state, thread_state, 1); - - Ok(Thread(self.pop_ref())) - } - } - - /// Wraps a Lua function into a new or recycled thread (coroutine). - #[cfg(feature = "async")] - pub(crate) fn create_recycled_thread<'lua>( - &'lua self, - func: Function<'lua>, - ) -> Result> { - #[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", - ))] - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 1)?; - - let extra = &mut *self.extra.get(); - if let Some(index) = extra.recycled_thread_cache.pop() { - let thread_state = ffi::lua_tothread(extra.ref_thread, index); - self.push_ref(&func.0); - ffi::lua_xmove(self.state, thread_state, 1); - - #[cfg(feature = "luau")] - { - // Inherit `LUA_GLOBALSINDEX` from the caller - ffi::lua_xpush(self.state, thread_state, ffi::LUA_GLOBALSINDEX); - ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); - } - - return Ok(Thread(LuaRef { lua: self, index })); - } - }; - self.create_thread(func) - } - - /// Resets thread (coroutine) and returns to the cache for later use. - #[cfg(feature = "async")] - #[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", - ))] - pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { - let extra = &mut *self.extra.get(); - let thread_state = ffi::lua_tothread(extra.ref_thread, thread.0.index); - if extra.recycled_thread_cache.len() < extra.recycled_thread_cache.capacity() { - #[cfg(feature = "lua54")] - let status = ffi::lua_resetthread(thread_state); - #[cfg(feature = "lua54")] - if status != ffi::LUA_OK { - return; - } - #[cfg(all(feature = "luajit", feature = "vendored"))] - ffi::lua_resetthread(self.state, thread_state); - #[cfg(feature = "luau")] - ffi::lua_resetthread(thread_state); - extra.recycled_thread_cache.push(thread.0.index); - thread.0.index = 0; - } - } - - /// Create a Lua userdata object from a custom userdata type. - pub fn create_userdata(&self, data: T) -> Result - where - T: 'static + MaybeSend + UserData, - { - unsafe { self.make_userdata(UserDataCell::new(data)) } - } - - /// Create a Lua userdata object from a custom serializable userdata type. - /// - /// Requires `feature = "serialize"` - #[cfg(feature = "serialize")] - #[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] - pub fn create_ser_userdata(&self, data: T) -> Result - where - T: 'static + MaybeSend + UserData + Serialize, - { - unsafe { self.make_userdata(UserDataCell::new_ser(data)) } - } - - /// Returns a handle to the global environment. - pub fn globals(&self) -> Table { - unsafe { - let _sg = StackGuard::new(self.state); - assert_stack(self.state, 1); - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_pushvalue(self.state, ffi::LUA_GLOBALSINDEX); - Table(self.pop_ref()) - } - } - - /// Returns a handle to the active `Thread`. For calls to `Lua` this will be the main Lua thread, - /// for parameters given to a callback, this will be whatever Lua thread called the callback. - pub fn current_thread(&self) -> Thread { - unsafe { - let _sg = StackGuard::new(self.state); - assert_stack(self.state, 1); - ffi::lua_pushthread(self.state); - Thread(self.pop_ref()) - } - } - - /// Calls the given function with a `Scope` parameter, giving the function the ability to create - /// userdata and callbacks from rust types that are !Send or non-'static. - /// - /// The lifetime of any function or userdata created through `Scope` lasts only until the - /// completion of this method call, on completion all such created values are automatically - /// dropped and Lua references to them are invalidated. If a script accesses a value created - /// through `Scope` outside of this method, a Lua error will result. Since we can ensure the - /// lifetime of values created through `Scope`, and we know that `Lua` cannot be sent to another - /// thread while `Scope` is live, it is safe to allow !Send datatypes and whose lifetimes only - /// outlive the scope lifetime. - /// - /// Inside the scope callback, all handles created through Scope will share the same unique 'lua - /// lifetime of the parent `Lua`. This allows scoped and non-scoped values to be mixed in - /// API calls, which is very useful (e.g. passing a scoped userdata to a non-scoped function). - /// However, this also enables handles to scoped values to be trivially leaked from the given - /// callback. This is not dangerous, though! After the callback returns, all scoped values are - /// invalidated, which means that though references may exist, the Rust types backing them have - /// dropped. `Function` types will error when called, and `AnyUserData` will be typeless. It - /// would be impossible to prevent handles to scoped values from escaping anyway, since you - /// would always be able to smuggle them through Lua state. - pub fn scope<'lua, 'scope, R, F>(&'lua self, f: F) -> Result - where - 'lua: 'scope, - R: 'static, - F: FnOnce(&Scope<'lua, 'scope>) -> Result, - { - f(&Scope::new(self)) - } - - /// An asynchronous version of [`scope`] that allows to create scoped async functions and - /// execute them. - /// - /// Requires `feature = "async"` - /// - /// [`scope`]: #method.scope - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn async_scope<'lua, 'scope, R, F, FR>( - &'lua self, - f: F, - ) -> LocalBoxFuture<'scope, Result> - where - 'lua: 'scope, - R: 'static, - F: FnOnce(Scope<'lua, 'scope>) -> FR, - FR: 'scope + Future>, - { - Box::pin(f(Scope::new(self))) - } - - /// Attempts to coerce a Lua value into a String in a manner consistent with Lua's internal - /// behavior. - /// - /// To succeed, the value must be a string (in which case this is a no-op), an integer, or a - /// number. - pub fn coerce_string<'lua>(&'lua self, v: Value<'lua>) -> Result>> { - Ok(match v { - Value::String(s) => Some(s), - v => unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 4)?; - - self.push_value(v)?; - let res = protect_lua!(self.state, 1, 1, |state| { - ffi::lua_tolstring(state, -1, ptr::null_mut()) - })?; - if !res.is_null() { - Some(String(self.pop_ref())) - } else { - None - } - }, - }) - } - - /// Attempts to coerce a Lua value into an integer in a manner consistent with Lua's internal - /// behavior. - /// - /// To succeed, the value must be an integer, a floating point number that has an exact - /// representation as an integer, or a string that can be converted to an integer. Refer to the - /// Lua manual for details. - pub fn coerce_integer(&self, v: Value) -> Result> { - Ok(match v { - Value::Integer(i) => Some(i), - v => unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - - self.push_value(v)?; - let mut isint = 0; - let i = ffi::lua_tointegerx(self.state, -1, &mut isint); - if isint == 0 { - None - } else { - Some(i) - } - }, - }) - } - - /// Attempts to coerce a Lua value into a Number in a manner consistent with Lua's internal - /// behavior. - /// - /// To succeed, the value must be a number or a string that can be converted to a number. Refer - /// to the Lua manual for details. - pub fn coerce_number(&self, v: Value) -> Result> { - Ok(match v { - Value::Number(n) => Some(n), - v => unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - - self.push_value(v)?; - let mut isnum = 0; - let n = ffi::lua_tonumberx(self.state, -1, &mut isnum); - if isnum == 0 { - None - } else { - Some(n) - } - }, - }) - } - - /// Converts a value that implements `ToLua` into a `Value` instance. - pub fn pack<'lua, T: ToLua<'lua>>(&'lua self, t: T) -> Result> { - t.to_lua(self) - } - - /// Converts a `Value` instance into a value that implements `FromLua`. - pub fn unpack<'lua, T: FromLua<'lua>>(&'lua self, value: Value<'lua>) -> Result { - T::from_lua(value, self) - } - - /// Converts a value that implements `ToLuaMulti` into a `MultiValue` instance. - pub fn pack_multi<'lua, T: ToLuaMulti<'lua>>(&'lua self, t: T) -> Result> { - t.to_lua_multi(self) - } - - /// Converts a `MultiValue` instance into a value that implements `FromLuaMulti`. - pub fn unpack_multi<'lua, T: FromLuaMulti<'lua>>( - &'lua self, - value: MultiValue<'lua>, - ) -> Result { - T::from_lua_multi(value, self) - } - - /// Set a value in the Lua registry based on a string name. - /// - /// This value will be available to rust from all `Lua` instances which share the same main - /// state. - pub fn set_named_registry_value<'lua, S, T>(&'lua self, name: &S, t: T) -> Result<()> - where - S: AsRef<[u8]> + ?Sized, - T: ToLua<'lua>, - { - let t = t.to_lua(self)?; - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 5)?; - - self.push_value(t)?; - rawset_field(self.state, ffi::LUA_REGISTRYINDEX, name) - } - } - - /// Get a value from the Lua registry based on a string name. - /// - /// Any Lua instance which shares the underlying main state may call this method to - /// get a value previously set by [`set_named_registry_value`]. - /// - /// [`set_named_registry_value`]: #method.set_named_registry_value - pub fn named_registry_value<'lua, S, T>(&'lua self, name: &S) -> Result - where - S: AsRef<[u8]> + ?Sized, - T: FromLua<'lua>, - { - let value = unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 3)?; - - push_string(self.state, name)?; - ffi::lua_rawget(self.state, ffi::LUA_REGISTRYINDEX); - - self.pop_value() - }; - T::from_lua(value, self) - } - - /// Removes a named value in the Lua registry. - /// - /// Equivalent to calling [`set_named_registry_value`] with a value of Nil. - /// - /// [`set_named_registry_value`]: #method.set_named_registry_value - pub fn unset_named_registry_value(&self, name: &S) -> Result<()> - where - S: AsRef<[u8]> + ?Sized, - { - self.set_named_registry_value(name, Nil) - } - - /// Place a value in the Lua registry with an auto-generated key. - /// - /// This value will be available to Rust from all `Lua` instances which share the same main - /// state. - /// - /// Be warned, garbage collection of values held inside the registry is not automatic, see - /// [`RegistryKey`] for more details. - /// However, dropped [`RegistryKey`]s automatically reused to store new values. - /// - /// [`RegistryKey`]: crate::RegistryKey - pub fn create_registry_value<'lua, T: ToLua<'lua>>(&'lua self, t: T) -> Result { - let t = t.to_lua(self)?; - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 4)?; - - let unref_list = (*self.extra.get()).registry_unref_list.clone(); - self.push_value(t)?; - - // Try to reuse previously allocated RegistryKey - let unref_list2 = unref_list.clone(); - let mut unref_list2 = mlua_expect!(unref_list2.lock(), "unref list poisoned"); - if let Some(registry_id) = unref_list2.as_mut().and_then(|x| x.pop()) { - // It must be safe to replace the value without triggering memory error - ffi::lua_rawseti(self.state, ffi::LUA_REGISTRYINDEX, registry_id as Integer); - return Ok(RegistryKey { - registry_id, - unref_list, - }); - } - - // Allocate a new RegistryKey - let registry_id = protect_lua!(self.state, 1, 0, |state| { - ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) - })?; - - Ok(RegistryKey { - registry_id, - unref_list, - }) - } - } - - /// Get a value from the Lua registry by its `RegistryKey` - /// - /// Any Lua instance which shares the underlying main state may call this method to get a value - /// previously placed by [`create_registry_value`]. - /// - /// [`create_registry_value`]: #method.create_registry_value - pub fn registry_value<'lua, T: FromLua<'lua>>(&'lua self, key: &RegistryKey) -> Result { - if !self.owns_registry_value(key) { - return Err(Error::MismatchedRegistryKey); - } - - let value = unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 1)?; - - ffi::lua_rawgeti( - self.state, - ffi::LUA_REGISTRYINDEX, - key.registry_id as Integer, - ); - self.pop_value() - }; - T::from_lua(value, self) - } - - /// Removes a value from the Lua registry. - /// - /// You may call this function to manually remove a value placed in the registry with - /// [`create_registry_value`]. In addition to manual `RegistryKey` removal, you can also call - /// [`expire_registry_values`] to automatically remove values from the registry whose - /// `RegistryKey`s have been dropped. - /// - /// [`create_registry_value`]: #method.create_registry_value - /// [`expire_registry_values`]: #method.expire_registry_values - pub fn remove_registry_value(&self, key: RegistryKey) -> Result<()> { - if !self.owns_registry_value(&key) { - return Err(Error::MismatchedRegistryKey); - } - unsafe { - ffi::luaL_unref(self.state, ffi::LUA_REGISTRYINDEX, key.take()); - } - Ok(()) - } - - /// Replaces a value in the Lua registry by its `RegistryKey`. - /// - /// See [`create_registry_value`] for more details. - /// - /// [`create_registry_value`]: #method.create_registry_value - pub fn replace_registry_value<'lua, T: ToLua<'lua>>( - &'lua self, - key: &RegistryKey, - t: T, - ) -> Result<()> { - if !self.owns_registry_value(key) { - return Err(Error::MismatchedRegistryKey); - } - - let t = t.to_lua(self)?; - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 2)?; - - self.push_value(t)?; - // It must be safe to replace the value without triggering memory error - ffi::lua_rawseti( - self.state, - ffi::LUA_REGISTRYINDEX, - key.registry_id as Integer, - ); - - Ok(()) - } - } - - /// Returns true if the given `RegistryKey` was created by a `Lua` which shares the underlying - /// main state with this `Lua` instance. - /// - /// Other than this, methods that accept a `RegistryKey` will return - /// `Error::MismatchedRegistryKey` if passed a `RegistryKey` that was not created with a - /// matching `Lua` state. - pub fn owns_registry_value(&self, key: &RegistryKey) -> bool { - let registry_unref_list = unsafe { &(*self.extra.get()).registry_unref_list }; - Arc::ptr_eq(&key.unref_list, registry_unref_list) - } - - /// Remove any registry values whose `RegistryKey`s have all been dropped. - /// - /// Unlike normal handle values, `RegistryKey`s do not automatically remove themselves on Drop, - /// but you can call this method to remove any unreachable registry values not manually removed - /// by `Lua::remove_registry_value`. - pub fn expire_registry_values(&self) { - unsafe { - let mut unref_list = mlua_expect!( - (*self.extra.get()).registry_unref_list.lock(), - "unref list poisoned" - ); - let unref_list = mem::replace(&mut *unref_list, Some(Vec::new())); - for id in mlua_expect!(unref_list, "unref list not set") { - ffi::luaL_unref(self.state, ffi::LUA_REGISTRYINDEX, id); - } - } - } - - /// Sets or replaces an application data object of type `T`. - /// - /// Application data could be accessed at any time by using [`Lua::app_data_ref()`] or [`Lua::app_data_mut()`] - /// methods where `T` is the data type. - /// - /// # Examples - /// - /// ``` - /// use mlua::{Lua, Result}; - /// - /// fn hello(lua: &Lua, _: ()) -> Result<()> { - /// let mut s = lua.app_data_mut::<&str>().unwrap(); - /// assert_eq!(*s, "hello"); - /// *s = "world"; - /// Ok(()) - /// } - /// - /// fn main() -> Result<()> { - /// let lua = Lua::new(); - /// lua.set_app_data("hello"); - /// lua.create_function(hello)?.call(())?; - /// let s = lua.app_data_ref::<&str>().unwrap(); - /// assert_eq!(*s, "world"); - /// Ok(()) - /// } - /// ``` - pub fn set_app_data(&self, data: T) { - let extra = unsafe { &mut (*self.extra.get()) }; - extra - .app_data - .try_borrow_mut() - .expect("cannot borrow mutably app data container") - .insert(TypeId::of::(), Box::new(data)); - } - - /// Gets a reference to an application data object stored by [`Lua::set_app_data()`] of type `T`. - pub fn app_data_ref(&self) -> Option> { - let extra = unsafe { &(*self.extra.get()) }; - let app_data = extra - .app_data - .try_borrow() - .expect("cannot borrow app data container"); - let value = app_data.get(&TypeId::of::())?.downcast_ref::()? as *const _; - Some(Ref::map(app_data, |_| unsafe { &*value })) - } - - /// Gets a mutable reference to an application data object stored by [`Lua::set_app_data()`] of type `T`. - pub fn app_data_mut(&self) -> Option> { - let extra = unsafe { &(*self.extra.get()) }; - let mut app_data = extra - .app_data - .try_borrow_mut() - .expect("cannot mutably borrow app data container"); - let value = app_data.get_mut(&TypeId::of::())?.downcast_mut::()? as *mut _; - Some(RefMut::map(app_data, |_| unsafe { &mut *value })) - } - - /// Removes an application data of type `T`. - pub fn remove_app_data(&self) -> Option { - let extra = unsafe { &mut (*self.extra.get()) }; - extra - .app_data - .try_borrow_mut() - .expect("cannot mutably borrow app data container") - .remove(&TypeId::of::()) - .and_then(|data| data.downcast().ok().map(|data| *data)) - } - - // Uses 2 stack spaces, does not call checkstack - pub(crate) unsafe fn push_value(&self, value: Value) -> Result<()> { - match value { - Value::Nil => { - ffi::lua_pushnil(self.state); - } - - Value::Boolean(b) => { - ffi::lua_pushboolean(self.state, if b { 1 } else { 0 }); - } - - Value::LightUserData(ud) => { - ffi::lua_pushlightuserdata(self.state, ud.0); - } - - Value::Integer(i) => { - ffi::lua_pushinteger(self.state, i); - } - - Value::Number(n) => { - ffi::lua_pushnumber(self.state, n); - } - - #[cfg(feature = "luau")] - Value::Vector(x, y, z) => { - ffi::lua_pushvector(self.state, x, y, z); - } - - Value::String(s) => { - self.push_ref(&s.0); - } - - Value::Table(t) => { - self.push_ref(&t.0); - } - - Value::Function(f) => { - self.push_ref(&f.0); - } - - Value::Thread(t) => { - self.push_ref(&t.0); - } - - Value::UserData(ud) => { - self.push_ref(&ud.0); - } - - Value::Error(err) => { - push_gc_userdata(self.state, WrappedFailure::Error(err))?; - } - } - - Ok(()) - } - - // Uses 2 stack spaces, does not call checkstack - pub(crate) unsafe fn pop_value(&self) -> Value { - let state = self.state; - match ffi::lua_type(state, -1) { - ffi::LUA_TNIL => { - ffi::lua_pop(state, 1); - Nil - } - - ffi::LUA_TBOOLEAN => { - let b = Value::Boolean(ffi::lua_toboolean(state, -1) != 0); - ffi::lua_pop(state, 1); - b - } - - ffi::LUA_TLIGHTUSERDATA => { - let ud = Value::LightUserData(LightUserData(ffi::lua_touserdata(state, -1))); - ffi::lua_pop(state, 1); - ud - } - - ffi::LUA_TNUMBER => { - if ffi::lua_isinteger(state, -1) != 0 { - let i = Value::Integer(ffi::lua_tointeger(state, -1)); - ffi::lua_pop(state, 1); - i - } else { - let n = Value::Number(ffi::lua_tonumber(state, -1)); - ffi::lua_pop(state, 1); - n - } - } - - #[cfg(feature = "luau")] - ffi::LUA_TVECTOR => { - let v = ffi::lua_tovector(state, -1); - mlua_debug_assert!(!v.is_null(), "vector is null"); - let vec = Value::Vector(*v, *v.add(1), *v.add(2)); - ffi::lua_pop(state, 1); - vec - } - - ffi::LUA_TSTRING => Value::String(String(self.pop_ref())), - - ffi::LUA_TTABLE => Value::Table(Table(self.pop_ref())), - - ffi::LUA_TFUNCTION => Value::Function(Function(self.pop_ref())), - - ffi::LUA_TUSERDATA => { - // We must prevent interaction with userdata types other than UserData OR a WrappedError. - // WrappedPanics are automatically resumed. - match get_gc_userdata::(state, -1).as_mut() { - Some(WrappedFailure::Error(err)) => { - let err = err.clone(); - ffi::lua_pop(state, 1); - Value::Error(err) - } - Some(WrappedFailure::Panic(panic)) => { - if let Some(panic) = panic.take() { - ffi::lua_pop(state, 1); - resume_unwind(panic); - } - // Previously resumed panic? - ffi::lua_pop(state, 1); - Nil - } - _ => Value::UserData(AnyUserData(self.pop_ref())), - } - } - - ffi::LUA_TTHREAD => Value::Thread(Thread(self.pop_ref())), - - #[cfg(feature = "luajit")] - ffi::LUA_TCDATA => { - ffi::lua_pop(state, 1); - // TODO: Fix this in a next major release - panic!("cdata objects cannot be handled by mlua yet"); - } - - _ => mlua_panic!("LUA_TNONE in pop_value"), - } - } - - // Pushes a LuaRef value onto the stack, uses 1 stack space, does not call checkstack - pub(crate) unsafe fn push_ref(&self, lref: &LuaRef) { - assert!( - Arc::ptr_eq(&lref.lua.extra, &self.extra), - "Lua instance passed Value created from a different main Lua state" - ); - let extra = &*self.extra.get(); - #[cfg(not(feature = "luau"))] - { - ffi::lua_pushvalue(extra.ref_thread, lref.index); - ffi::lua_xmove(extra.ref_thread, self.state, 1); - } - #[cfg(feature = "luau")] - ffi::lua_xpush(extra.ref_thread, self.state, lref.index); - } - - // Pops the topmost element of the stack and stores a reference to it. This pins the object, - // preventing garbage collection until the returned `LuaRef` is dropped. - // - // References are stored in the stack of a specially created auxiliary thread that exists only - // to store reference values. This is much faster than storing these in the registry, and also - // much more flexible and requires less bookkeeping than storing them directly in the currently - // used stack. The implementation is somewhat biased towards the use case of a relatively small - // number of short term references being created, and `RegistryKey` being used for long term - // references. - pub(crate) unsafe fn pop_ref(&self) -> LuaRef { - let extra = &mut *self.extra.get(); - ffi::lua_xmove(self.state, extra.ref_thread, 1); - let index = ref_stack_pop(extra); - LuaRef { lua: self, index } - } - - pub(crate) fn clone_ref<'lua>(&'lua self, lref: &LuaRef<'lua>) -> LuaRef<'lua> { - unsafe { - let extra = &mut *self.extra.get(); - ffi::lua_pushvalue(extra.ref_thread, lref.index); - let index = ref_stack_pop(extra); - LuaRef { lua: self, index } - } - } - - pub(crate) fn drop_ref(&self, lref: &LuaRef) { - unsafe { - let extra = &mut *self.extra.get(); - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, lref.index); - extra.ref_free.push(lref.index); - } - } - - /// Executes the function provided on the ref thread - #[inline] - pub(crate) unsafe fn ref_thread_exec(&self, f: F) -> R - where - F: FnOnce(*mut ffi::lua_State) -> R, - { - let ref_thread = (*self.extra.get()).ref_thread; - f(ref_thread) - } - - unsafe fn push_userdata_metatable(&self) -> Result<()> { - let extra = &mut *self.extra.get(); - - let type_id = TypeId::of::(); - if let Some(&table_id) = extra.registered_userdata.get(&type_id) { - ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, table_id as Integer); - return Ok(()); - } - - let _sg = StackGuard::new_extra(self.state, 1); - check_stack(self.state, 13)?; - - let mut fields = StaticUserDataFields::default(); - let mut methods = StaticUserDataMethods::default(); - T::add_fields(&mut fields); - T::add_methods(&mut methods); - - // Prepare metatable, add meta methods first and then meta fields - let metatable_nrec = methods.meta_methods.len() + fields.meta_fields.len(); - #[cfg(feature = "async")] - let metatable_nrec = metatable_nrec + methods.async_meta_methods.len(); - push_table(self.state, 0, metatable_nrec as c_int)?; - for (k, m) in methods.meta_methods { - self.push_value(Value::Function(self.create_callback(m)?))?; - rawset_field(self.state, -2, k.validate()?.name())?; - } - #[cfg(feature = "async")] - for (k, m) in methods.async_meta_methods { - self.push_value(Value::Function(self.create_async_callback(m)?))?; - rawset_field(self.state, -2, k.validate()?.name())?; - } - for (k, f) in fields.meta_fields { - self.push_value(f(self)?)?; - rawset_field(self.state, -2, k.validate()?.name())?; - } - let metatable_index = ffi::lua_absindex(self.state, -1); - - let mut extra_tables_count = 0; - - let mut field_getters_index = None; - let field_getters_nrec = fields.field_getters.len(); - if field_getters_nrec > 0 { - push_table(self.state, 0, field_getters_nrec as c_int)?; - for (k, m) in fields.field_getters { - self.push_value(Value::Function(self.create_callback(m)?))?; - rawset_field(self.state, -2, &k)?; - } - field_getters_index = Some(ffi::lua_absindex(self.state, -1)); - extra_tables_count += 1; - } - - let mut field_setters_index = None; - let field_setters_nrec = fields.field_setters.len(); - if field_setters_nrec > 0 { - push_table(self.state, 0, field_setters_nrec as c_int)?; - for (k, m) in fields.field_setters { - self.push_value(Value::Function(self.create_callback(m)?))?; - rawset_field(self.state, -2, &k)?; - } - field_setters_index = Some(ffi::lua_absindex(self.state, -1)); - extra_tables_count += 1; - } - - let mut methods_index = None; - let methods_nrec = methods.methods.len(); - #[cfg(feature = "async")] - let methods_nrec = methods_nrec + methods.async_methods.len(); - if methods_nrec > 0 { - push_table(self.state, 0, methods_nrec as c_int)?; - for (k, m) in methods.methods { - self.push_value(Value::Function(self.create_callback(m)?))?; - rawset_field(self.state, -2, &k)?; - } - #[cfg(feature = "async")] - for (k, m) in methods.async_methods { - self.push_value(Value::Function(self.create_async_callback(m)?))?; - rawset_field(self.state, -2, &k)?; - } - methods_index = Some(ffi::lua_absindex(self.state, -1)); - extra_tables_count += 1; - } - - init_userdata_metatable::>( - self.state, - metatable_index, - field_getters_index, - field_setters_index, - methods_index, - )?; - - // Pop extra tables to get metatable on top of the stack - ffi::lua_pop(self.state, extra_tables_count); - - let mt_ptr = ffi::lua_topointer(self.state, -1); - ffi::lua_pushvalue(self.state, -1); - let id = protect_lua!(self.state, 1, 0, |state| { - ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) - })?; - - extra.registered_userdata.insert(type_id, id); - extra.registered_userdata_mt.insert(mt_ptr, Some(type_id)); - - Ok(()) - } - - pub(crate) unsafe fn register_userdata_metatable( - &self, - ptr: *const c_void, - type_id: Option, - ) { - let extra = &mut *self.extra.get(); - extra.registered_userdata_mt.insert(ptr, type_id); - } - - pub(crate) unsafe fn deregister_userdata_metatable(&self, ptr: *const c_void) { - (*self.extra.get()).registered_userdata_mt.remove(&ptr); - } - - // Pushes a LuaRef value onto the stack, checking that it's a registered - // and not destructed UserData. - // Uses 2 stack spaces, does not call checkstack. - pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result> { - self.push_ref(lref); - if ffi::lua_getmetatable(self.state, -1) == 0 { - return Err(Error::UserDataTypeMismatch); - } - let mt_ptr = ffi::lua_topointer(self.state, -1); - ffi::lua_pop(self.state, 1); - - let extra = &*self.extra.get(); - match extra.registered_userdata_mt.get(&mt_ptr) { - Some(&type_id) if type_id == Some(TypeId::of::()) => { - Err(Error::UserDataDestructed) - } - Some(&type_id) => Ok(type_id), - None => Err(Error::UserDataTypeMismatch), - } - } - - // Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the - // Fn is 'static, otherwise it could capture 'lua arguments improperly. Without ATCs, we - // cannot easily deal with the "correct" callback type of: - // - // Box Fn(&'lua Lua, MultiValue<'lua>) -> Result>)> - // - // So we instead use a caller provided lifetime, which without the 'static requirement would be - // unsafe. - pub(crate) fn create_callback<'lua>( - &'lua self, - func: Callback<'lua, 'static>, - ) -> Result> { - unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { - ffi::LUA_TUSERDATA => { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).extra.get() - } - _ => ptr::null_mut(), - }; - callback_error_ext(state, extra, |nargs| { - let upvalue_idx = ffi::lua_upvalueindex(1); - if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL { - return Err(Error::CallbackDestructed); - } - let upvalue = get_userdata::(state, upvalue_idx); - - if nargs < ffi::LUA_MINSTACK { - check_stack(state, ffi::LUA_MINSTACK - nargs)?; - } - - let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); - let _guard = StateGuard::new(&mut *lua.0.get(), state); - - let mut args = MultiValue::new_or_cached(lua); - args.reserve(nargs as usize); - for _ in 0..nargs { - args.push_front(lua.pop_value()); - } - - let func = &*(*upvalue).data; - let mut results = func(lua, args)?; - let nresults = results.len() as c_int; - - check_stack(state, nresults)?; - for r in results.drain_all() { - lua.push_value(r)?; - } - lua.cache_multivalue(results); - - Ok(nresults) - }) - } - - unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 4)?; - - let func = mem::transmute(func); - let extra = Arc::clone(&self.extra); - push_gc_userdata(self.state, CallbackUpvalue { data: func, extra })?; - protect_lua!(self.state, 1, 1, fn(state) { - ffi::lua_pushcclosure(state, call_callback, 1); - })?; - - Ok(Function(self.pop_ref())) - } - } - - #[cfg(feature = "async")] - pub(crate) fn create_async_callback<'lua>( - &'lua self, - func: AsyncCallback<'lua, 'static>, - ) -> Result> { - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luau" - ))] - unsafe { - let libs = (*self.extra.get()).libs; - if !libs.contains(StdLib::COROUTINE) { - self.load_from_std_lib(StdLib::COROUTINE)?; - } - } - - unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { - let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { - ffi::LUA_TUSERDATA => { - let upvalue = - get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).extra.get() - } - _ => ptr::null_mut(), - }; - callback_error_ext(state, extra, |nargs| { - let upvalue_idx = ffi::lua_upvalueindex(1); - if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL { - return Err(Error::CallbackDestructed); - } - let upvalue = get_userdata::(state, upvalue_idx); - - if nargs < ffi::LUA_MINSTACK { - check_stack(state, ffi::LUA_MINSTACK - nargs)?; - } - - let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); - let _guard = StateGuard::new(&mut *lua.0.get(), state); - - let mut args = MultiValue::new_or_cached(lua); - args.reserve(nargs as usize); - for _ in 0..nargs { - args.push_front(lua.pop_value()); - } - - let func = &*(*upvalue).data; - let fut = func(lua, args); - let extra = Arc::clone(&(*upvalue).extra); - push_gc_userdata(state, AsyncPollUpvalue { data: fut, extra })?; - protect_lua!(state, 1, 1, fn(state) { - ffi::lua_pushcclosure(state, poll_future, 1); - })?; - - Ok(1) - }) - } - - unsafe extern "C" fn poll_future(state: *mut ffi::lua_State) -> c_int { - let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { - ffi::LUA_TUSERDATA => { - let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).extra.get() - } - _ => ptr::null_mut(), - }; - callback_error_ext(state, extra, |nargs| { - let upvalue_idx = ffi::lua_upvalueindex(1); - if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL { - return Err(Error::CallbackDestructed); - } - let upvalue = get_userdata::(state, upvalue_idx); - - if nargs < ffi::LUA_MINSTACK { - check_stack(state, ffi::LUA_MINSTACK - nargs)?; - } - - let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); - let _guard = StateGuard::new(&mut *lua.0.get(), state); - - // Try to get an outer poll waker - let waker = lua.waker().unwrap_or_else(noop_waker); - let mut ctx = Context::from_waker(&waker); - - let fut = &mut (*upvalue).data; - match fut.as_mut().poll(&mut ctx) { - Poll::Pending => { - check_stack(state, 1)?; - ffi::lua_pushboolean(state, 0); - Ok(1) - } - Poll::Ready(results) => { - let results = results?; - let nresults = results.len() as Integer; - let results = lua.create_sequence_from(results)?; - check_stack(state, 3)?; - ffi::lua_pushboolean(state, 1); - lua.push_value(Value::Table(results))?; - lua.push_value(Value::Integer(nresults))?; - Ok(3) - } - } - }) - } - - let get_poll = unsafe { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 4)?; - - let func = mem::transmute(func); - let extra = Arc::clone(&self.extra); - push_gc_userdata(self.state, AsyncCallbackUpvalue { data: func, extra })?; - protect_lua!(self.state, 1, 1, fn(state) { - ffi::lua_pushcclosure(state, call_callback, 1); - })?; - - Function(self.pop_ref()) - }; - - unsafe extern "C" fn unpack(state: *mut ffi::lua_State) -> c_int { - let len = ffi::lua_tointeger(state, 2); - ffi::luaL_checkstack(state, len as c_int, ptr::null()); - for i in 1..=len { - ffi::lua_rawgeti(state, 1, i); - } - len as c_int - } - - let coroutine = self.globals().get::<_, Table>("coroutine")?; - - let env = self.create_table_with_capacity(0, 4)?; - env.set("get_poll", get_poll)?; - env.set("yield", coroutine.get::<_, Function>("yield")?)?; - unsafe { - env.set("unpack", self.create_c_function(unpack)?)?; - } - env.set("pending", { - LightUserData(&ASYNC_POLL_PENDING as *const u8 as *mut c_void) - })?; - - // We set `poll` variable in the env table to be able to destroy upvalues - self.load( - r#" - poll = get_poll(...) - local poll, pending, yield, unpack = poll, pending, yield, unpack - while true do - local ready, res, nres = poll() - if ready then - return unpack(res, nres) - end - yield(pending) - end - "#, - ) - .try_cache() - .set_name("_mlua_async_poll")? - .set_environment(env)? - .into_function() - } - - #[cfg(feature = "async")] - #[inline] - pub(crate) unsafe fn waker(&self) -> Option { - let extra = &*self.extra.get(); - (*get_userdata::>(extra.ref_thread, extra.ref_waker_idx)).clone() - } - - #[cfg(feature = "async")] - #[inline] - pub(crate) unsafe fn set_waker(&self, waker: Option) -> Option { - let extra = &*self.extra.get(); - let waker_slot = &mut *get_userdata::>(extra.ref_thread, extra.ref_waker_idx); - match waker { - Some(waker) => waker_slot.replace(waker), - None => waker_slot.take(), - } - } - - pub(crate) unsafe fn make_userdata(&self, data: UserDataCell) -> Result - where - T: 'static + UserData, - { - let _sg = StackGuard::new(self.state); - check_stack(self.state, 3)?; - - // We push metatable first to ensure having correct metatable with `__gc` method - ffi::lua_pushnil(self.state); - self.push_userdata_metatable::()?; - #[cfg(not(feature = "lua54"))] - push_userdata(self.state, data)?; - #[cfg(feature = "lua54")] - push_userdata_uv(self.state, data, USER_VALUE_MAXSLOT as c_int)?; - ffi::lua_replace(self.state, -3); - ffi::lua_setmetatable(self.state, -2); - - // Set empty environment for Lua 5.1 - #[cfg(any(feature = "lua51", feature = "luajit"))] - protect_lua!(self.state, 1, 1, fn(state) { - ffi::lua_newtable(state); - ffi::lua_setuservalue(state, -2); - })?; - - Ok(AnyUserData(self.pop_ref())) - } - - #[cfg(not(feature = "luau"))] - fn disable_c_modules(&self) -> Result<()> { - let package: Table = self.globals().get("package")?; - - package.set( - "loadlib", - self.create_function(|_, ()| -> Result<()> { - Err(Error::SafetyError( - "package.loadlib is disabled in safe mode".to_string(), - )) - })?, - )?; - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - let searchers: Table = package.get("searchers")?; - #[cfg(any(feature = "lua51", feature = "luajit"))] - let searchers: Table = package.get("loaders")?; - - let loader = self.create_function(|_, ()| Ok("\n\tcan't load C modules in safe mode"))?; - - // The third and fourth searchers looks for a loader as a C library - searchers.raw_set(3, loader.clone())?; - searchers.raw_remove(4)?; - - Ok(()) - } - - pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Option { - let _sg = StackGuard::new(state); - assert_stack(state, 1); - let extra = extra_data(state)?; - let inner = &*(*extra.get()).inner.as_ref().unwrap(); - Some(Lua(Arc::clone(inner))) - } - - #[inline] - pub(crate) fn new_or_cached_multivalue(&self) -> MultiValue { - unsafe { - let extra = &mut *self.extra.get(); - extra.multivalue_cache.pop().unwrap_or_default() - } - } - - #[inline] - pub(crate) fn cache_multivalue(&self, mut multivalue: MultiValue) { - unsafe { - let extra = &mut *self.extra.get(); - if extra.multivalue_cache.len() < MULTIVALUE_CACHE_SIZE { - multivalue.clear(); - extra.multivalue_cache.push(mem::transmute(multivalue)); - } - } - } -} - -struct StateGuard<'a>(&'a mut LuaInner, *mut ffi::lua_State); - -impl<'a> StateGuard<'a> { - fn new(inner: &'a mut LuaInner, mut state: *mut ffi::lua_State) -> Self { - mem::swap(&mut (*inner).state, &mut state); - Self(inner, state) - } -} - -impl<'a> Drop for StateGuard<'a> { - fn drop(&mut self) { - mem::swap(&mut (*self.0).state, &mut self.1); - } -} - -#[cfg(feature = "luau")] -unsafe fn extra_data(state: *mut ffi::lua_State) -> Option>> { - let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc>; - if extra_ptr.is_null() { - return None; - } - Some(Arc::clone(&*extra_ptr)) -} - -#[cfg(not(feature = "luau"))] -unsafe fn extra_data(state: *mut ffi::lua_State) -> Option>> { - let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { - return None; - } - let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Arc>; - let extra = Arc::clone(&*extra_ptr); - ffi::lua_pop(state, 1); - Some(extra) -} - -// Creates required entries in the metatable cache (see `util::METATABLE_CACHE`) -pub(crate) fn init_metatable_cache(cache: &mut FxHashMap) { - cache.insert(TypeId::of::>>(), 0); - cache.insert(TypeId::of::(), 0); - cache.insert(TypeId::of::(), 0); - - #[cfg(feature = "async")] - { - cache.insert(TypeId::of::(), 0); - cache.insert(TypeId::of::(), 0); - cache.insert(TypeId::of::(), 0); - cache.insert(TypeId::of::>(), 0); - } -} - -// An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata -// and instead reuses unsed and cached values from previous calls (or allocates new). -unsafe fn callback_error_ext(state: *mut ffi::lua_State, extra: *mut ExtraData, f: F) -> R -where - F: FnOnce(c_int) -> Result, -{ - if extra.is_null() { - return callback_error(state, f); - } - let extra = &mut *extra; - - let nargs = ffi::lua_gettop(state); - - // We need 2 extra stack spaces to store userdata and error/panic metatable. - // Luau workaround can be removed after solving https://github.com/Roblox/luau/issues/446 - // Also see #142 and #153 - if !cfg!(feature = "luau") || extra.wrapped_failures_cache.is_empty() { - let extra_stack = if nargs < 2 { 2 - nargs } else { 1 }; - ffi::luaL_checkstack( - state, - extra_stack, - cstr!("not enough stack space for callback error handling"), - ); - } - - enum PreallocatedFailure { - New(*mut WrappedFailure), - Cached(i32), - } - - // We cannot shadow Rust errors with Lua ones, so we need to obtain pre-allocated memory - // to store a wrapped failure (error or panic) *before* we proceed. - let prealloc_failure = match extra.wrapped_failures_cache.pop() { - Some(index) => PreallocatedFailure::Cached(index), - None => { - let ud = WrappedFailure::new_userdata(state); - ffi::lua_rotate(state, 1, 1); - PreallocatedFailure::New(ud) - } - }; - - let mut get_wrapped_failure = || match prealloc_failure { - PreallocatedFailure::New(ud) => { - ffi::lua_settop(state, 1); - ud - } - PreallocatedFailure::Cached(index) => { - ffi::lua_settop(state, 0); - #[cfg(feature = "luau")] - assert_stack(state, 2); - ffi::lua_pushvalue(extra.ref_thread, index); - ffi::lua_xmove(extra.ref_thread, state, 1); - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, index); - extra.ref_free.push(index); - ffi::lua_touserdata(state, -1) as *mut WrappedFailure - } - }; - - match catch_unwind(AssertUnwindSafe(|| f(nargs))) { - Ok(Ok(r)) => { - // Return unused WrappedFailure to the cache - match prealloc_failure { - PreallocatedFailure::New(_) - if extra.wrapped_failures_cache.len() < WRAPPED_FAILURES_CACHE_SIZE => - { - ffi::lua_rotate(state, 1, -1); - ffi::lua_xmove(state, extra.ref_thread, 1); - let index = ref_stack_pop(extra); - extra.wrapped_failures_cache.push(index); - } - PreallocatedFailure::New(_) => { - ffi::lua_remove(state, 1); - } - PreallocatedFailure::Cached(index) - if extra.wrapped_failures_cache.len() < WRAPPED_FAILURES_CACHE_SIZE => - { - extra.wrapped_failures_cache.push(index); - } - PreallocatedFailure::Cached(index) => { - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, index); - extra.ref_free.push(index); - } - } - r - } - Ok(Err(err)) => { - let wrapped_error = get_wrapped_failure(); - - // Build `CallbackError` with traceback - let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { - ffi::luaL_traceback(state, state, ptr::null(), 0); - let traceback = util::to_string(state, -1); - ffi::lua_pop(state, 1); - traceback - } else { - "".to_string() - }; - let cause = Arc::new(err); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::CallbackError { traceback, cause }), - ); - get_gc_metatable::(state); - ffi::lua_setmetatable(state, -2); - - ffi::lua_error(state) - } - Err(p) => { - let wrapped_panic = get_wrapped_failure(); - ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); - get_gc_metatable::(state); - ffi::lua_setmetatable(state, -2); - ffi::lua_error(state) - } - } -} - -// Uses 3 stack spaces -unsafe fn load_from_std_lib(state: *mut ffi::lua_State, libs: StdLib) -> Result<()> { - #[inline(always)] - pub unsafe fn requiref + ?Sized>( - state: *mut ffi::lua_State, - modname: &S, - openf: ffi::lua_CFunction, - glb: c_int, - ) -> Result<()> { - let modname = mlua_expect!(CString::new(modname.as_ref()), "modname contains nil byte"); - protect_lua!(state, 0, 1, |state| { - ffi::luaL_requiref(state, modname.as_ptr() as *const c_char, openf, glb) - }) - } - - #[cfg(feature = "luajit")] - struct GcGuard(*mut ffi::lua_State); - - #[cfg(feature = "luajit")] - impl GcGuard { - fn new(state: *mut ffi::lua_State) -> Self { - // Stop collector during library initialization - unsafe { ffi::lua_gc(state, ffi::LUA_GCSTOP, 0) }; - GcGuard(state) - } - } - - #[cfg(feature = "luajit")] - impl Drop for GcGuard { - fn drop(&mut self) { - unsafe { ffi::lua_gc(self.0, ffi::LUA_GCRESTART, -1) }; - } - } - - // Stop collector during library initialization - #[cfg(feature = "luajit")] - let _gc_guard = GcGuard::new(state); - - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luau" - ))] - { - if libs.contains(StdLib::COROUTINE) { - requiref(state, ffi::LUA_COLIBNAME, ffi::luaopen_coroutine, 1)?; - ffi::lua_pop(state, 1); - } - } - - if libs.contains(StdLib::TABLE) { - requiref(state, ffi::LUA_TABLIBNAME, ffi::luaopen_table, 1)?; - ffi::lua_pop(state, 1); - } - - #[cfg(not(feature = "luau"))] - if libs.contains(StdLib::IO) { - requiref(state, ffi::LUA_IOLIBNAME, ffi::luaopen_io, 1)?; - ffi::lua_pop(state, 1); - } - - if libs.contains(StdLib::OS) { - requiref(state, ffi::LUA_OSLIBNAME, ffi::luaopen_os, 1)?; - ffi::lua_pop(state, 1); - } - - if libs.contains(StdLib::STRING) { - requiref(state, ffi::LUA_STRLIBNAME, ffi::luaopen_string, 1)?; - ffi::lua_pop(state, 1); - } - - #[cfg(any(feature = "lua54", feature = "lua53", feature = "luau"))] - { - if libs.contains(StdLib::UTF8) { - requiref(state, ffi::LUA_UTF8LIBNAME, ffi::luaopen_utf8, 1)?; - ffi::lua_pop(state, 1); - } - } - - #[cfg(any(feature = "lua52", feature = "luau"))] - { - if libs.contains(StdLib::BIT) { - requiref(state, ffi::LUA_BITLIBNAME, ffi::luaopen_bit32, 1)?; - ffi::lua_pop(state, 1); - } - } - - #[cfg(feature = "luajit")] - { - if libs.contains(StdLib::BIT) { - requiref(state, ffi::LUA_BITLIBNAME, ffi::luaopen_bit, 1)?; - ffi::lua_pop(state, 1); - } - } - - if libs.contains(StdLib::MATH) { - requiref(state, ffi::LUA_MATHLIBNAME, ffi::luaopen_math, 1)?; - ffi::lua_pop(state, 1); - } - - if libs.contains(StdLib::DEBUG) { - requiref(state, ffi::LUA_DBLIBNAME, ffi::luaopen_debug, 1)?; - ffi::lua_pop(state, 1); - } - - #[cfg(not(feature = "luau"))] - if libs.contains(StdLib::PACKAGE) { - requiref(state, ffi::LUA_LOADLIBNAME, ffi::luaopen_package, 1)?; - ffi::lua_pop(state, 1); - } - - #[cfg(feature = "luajit")] - { - if libs.contains(StdLib::JIT) { - requiref(state, ffi::LUA_JITLIBNAME, ffi::luaopen_jit, 1)?; - ffi::lua_pop(state, 1); - } - - if libs.contains(StdLib::FFI) { - requiref(state, ffi::LUA_FFILIBNAME, ffi::luaopen_ffi, 1)?; - ffi::lua_pop(state, 1); - } - } - - Ok(()) -} - -unsafe fn ref_stack_pop(extra: &mut ExtraData) -> c_int { - if let Some(free) = extra.ref_free.pop() { - ffi::lua_replace(extra.ref_thread, free); - return free; - } - - // Try to grow max stack size - if extra.ref_stack_top >= extra.ref_stack_size { - let mut inc = extra.ref_stack_size; // Try to double stack size - while inc > 0 && ffi::lua_checkstack(extra.ref_thread, inc) == 0 { - inc /= 2; - } - if inc == 0 { - // Pop item on top of the stack to avoid stack leaking and successfully run destructors - // during unwinding. - ffi::lua_pop(extra.ref_thread, 1); - let top = extra.ref_stack_top; - // It is a user error to create enough references to exhaust the Lua max stack size for - // the ref thread. - panic!( - "cannot create a Lua reference, out of auxiliary stack space (used {} slots)", - top - ); - } - extra.ref_stack_size += inc; - } - extra.ref_stack_top += 1; - extra.ref_stack_top -} diff --git a/src/luau.rs b/src/luau.rs deleted file mode 100644 index b9b2adfa..00000000 --- a/src/luau.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::ffi::CStr; -use std::os::raw::{c_float, c_int}; - -use crate::chunk::ChunkMode; -use crate::error::{Error, Result}; -use crate::ffi; -use crate::lua::Lua; -use crate::table::Table; -use crate::util::{check_stack, StackGuard}; -use crate::value::Value; - -// Since Luau has some missing standard function, we re-implement them here - -impl Lua { - pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> { - let globals = self.globals(); - - globals.raw_set( - "collectgarbage", - self.create_c_function(lua_collectgarbage)?, - )?; - globals.raw_set("require", self.create_function(lua_require)?)?; - globals.raw_set("vector", self.create_c_function(lua_vector)?)?; - - Ok(()) - } -} - -unsafe extern "C" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int { - let option = ffi::luaL_optstring(state, 1, cstr!("collect")); - let option = CStr::from_ptr(option); - let arg = ffi::luaL_optinteger(state, 2, 0); - match option.to_str() { - Ok("collect") => { - ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); - 0 - } - Ok("stop") => { - ffi::lua_gc(state, ffi::LUA_GCSTOP, 0); - 0 - } - Ok("restart") => { - ffi::lua_gc(state, ffi::LUA_GCRESTART, 0); - 0 - } - Ok("count") => { - let kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0) as ffi::lua_Number; - let kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0) as ffi::lua_Number; - ffi::lua_pushnumber(state, kbytes + kbytes_rem / 1024.0); - 1 - } - Ok("step") => { - let res = ffi::lua_gc(state, ffi::LUA_GCSTEP, arg); - ffi::lua_pushboolean(state, res); - 1 - } - Ok("isrunning") => { - let res = ffi::lua_gc(state, ffi::LUA_GCISRUNNING, 0); - ffi::lua_pushboolean(state, res); - 1 - } - _ => ffi::luaL_error(state, cstr!("collectgarbage called with invalid option")), - } -} - -fn lua_require(lua: &Lua, name: Option) -> Result { - let name = name.ok_or_else(|| Error::RuntimeError("invalid module name".into()))?; - - // Find module in the cache - let loaded = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - protect_lua!(lua.state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(lua.pop_ref()) - }; - if let Some(v) = loaded.raw_get(name.clone())? { - return Ok(v); - } - - // Load file from filesystem - let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); - if search_path.is_empty() { - search_path = "?.luau;?.lua".into(); - } - - let mut source = None; - for path in search_path.split(';') { - if let Ok(buf) = std::fs::read(path.replacen('?', &name, 1)) { - source = Some(buf); - break; - } - } - let source = source.ok_or_else(|| Error::RuntimeError(format!("cannot find '{}'", name)))?; - - let value = lua - .load(&source) - .set_name(&format!("={}", name))? - .set_mode(ChunkMode::Text) - .call::<_, Value>(())?; - - // Save in the cache - loaded.raw_set( - name, - match value.clone() { - Value::Nil => Value::Boolean(true), - v => v, - }, - )?; - - Ok(value) -} - -// Luau vector datatype constructor -unsafe extern "C" fn lua_vector(state: *mut ffi::lua_State) -> c_int { - let x = ffi::luaL_checknumber(state, 1) as c_float; - let y = ffi::luaL_checknumber(state, 2) as c_float; - let z = ffi::luaL_checknumber(state, 3) as c_float; - ffi::lua_pushvector(state, x, y, z); - 1 -} diff --git a/src/luau/heap_dump.rs b/src/luau/heap_dump.rs new file mode 100644 index 00000000..0189845c --- /dev/null +++ b/src/luau/heap_dump.rs @@ -0,0 +1,178 @@ +use std::collections::HashMap; +use std::hash::Hash; +use std::mem; +use std::os::raw::c_char; + +use crate::state::ExtraData; + +use super::json::{self, Json}; + +/// Represents a heap dump of a Luau memory state. +#[cfg(any(feature = "luau", doc))] +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +pub struct HeapDump { + data: Json<'static>, // refers to the contents of `buf` + buf: Box, +} + +impl HeapDump { + /// Dumps the current Lua heap state. + pub(crate) unsafe fn new(state: *mut ffi::lua_State) -> Option { + unsafe extern "C" fn category_name(state: *mut ffi::lua_State, cat: u8) -> *const c_char { + (&*ExtraData::get(state)) + .mem_categories + .get(cat as usize) + .map(|s| s.as_ptr()) + .unwrap_or(cstr!("unknown")) + } + + let mut buf = Vec::new(); + unsafe { + let file = libc::tmpfile(); + if file.is_null() { + return None; + } + ffi::lua_gcdump(state, file as *mut _, Some(category_name)); + libc::fseek(file, 0, libc::SEEK_END); + let len = libc::ftell(file) as usize; + libc::rewind(file); + if len > 0 { + buf.reserve(len); + libc::fread(buf.as_mut_ptr() as *mut _, 1, len, file); + buf.set_len(len); + } + libc::fclose(file); + } + + let buf = String::from_utf8(buf).ok()?.into_boxed_str(); + let data = json::parse(unsafe { mem::transmute::<&str, &'static str>(&buf) }).ok()?; + Some(HeapDump { data, buf }) + } + + /// Returns the raw JSON representation of the heap dump. + /// + /// The JSON structure is an internal detail and may change in future versions. + #[doc(hidden)] + pub fn to_json(&self) -> &str { + &self.buf + } + + /// Returns the total size of the Lua heap in bytes. + pub fn size(&self) -> u64 { + self.data["stats"]["size"].as_u64().unwrap_or_default() + } + + /// Returns a mapping from object type to (count, total size in bytes). + /// + /// If `category` is provided, only objects in that category are considered. + pub fn size_by_type<'a>(&'a self, category: Option<&str>) -> HashMap<&'a str, (usize, u64)> { + self.size_by_type_inner(category).unwrap_or_default() + } + + fn size_by_type_inner<'a>(&'a self, category: Option<&str>) -> Option> { + let category_id = match category { + // If we cannot find the category, return empty result + Some(cat) => Some(self.find_category_id(cat)?), + None => None, + }; + + let mut size_by_type = HashMap::new(); + let objects = self.data["objects"].as_object()?; + for obj in objects.values() { + if let Some(cat_id) = category_id + && obj["cat"].as_i64()? != cat_id + { + continue; + } + update_size(&mut size_by_type, obj["type"].as_str()?, obj["size"].as_u64()?); + } + Some(size_by_type) + } + + /// Returns a mapping from category name to total size in bytes. + pub fn size_by_category(&self) -> HashMap<&str, u64> { + let mut size_by_category = HashMap::new(); + if let Some(categories) = self.data["stats"]["categories"].as_object() { + for cat in categories.values() { + if let Some(cat_name) = cat["name"].as_str() { + size_by_category.insert(cat_name, cat["size"].as_u64().unwrap_or_default()); + } + } + } + size_by_category + } + + /// Returns a mapping from userdata type to (count, total size in bytes). + pub fn size_by_userdata<'a>(&'a self, category: Option<&str>) -> HashMap<&'a str, (usize, u64)> { + self.size_by_userdata_inner(category).unwrap_or_default() + } + + fn size_by_userdata_inner<'a>( + &'a self, + category: Option<&str>, + ) -> Option> { + let category_id = match category { + // If we cannot find the category, return empty result + Some(cat) => Some(self.find_category_id(cat)?), + None => None, + }; + + let mut size_by_userdata = HashMap::new(); + let objects = self.data["objects"].as_object()?; + for obj in objects.values() { + if obj["type"] != "userdata" { + continue; + } + if let Some(cat_id) = category_id + && obj["cat"].as_i64()? != cat_id + { + continue; + } + + // Determine userdata type from metatable + let mut ud_type = "unknown"; + if let Some(metatable_addr) = obj["metatable"].as_str() + && let Some(t) = get_key(objects, &objects[metatable_addr], "__type") + { + ud_type = t; + } + update_size(&mut size_by_userdata, ud_type, obj["size"].as_u64()?); + } + Some(size_by_userdata) + } + + /// Finds the category ID for a given category name. + fn find_category_id(&self, category: &str) -> Option { + let categories = self.data["stats"]["categories"].as_object()?; + for (cat_id, cat) in categories { + if cat["name"].as_str() == Some(category) { + return cat_id.parse().ok(); + } + } + None + } +} + +/// Updates the size mapping for a given key. +fn update_size(size_type: &mut HashMap, key: K, size: u64) { + let (count, total_size) = size_type.entry(key).or_insert((0, 0)); + *count += 1; + *total_size += size; +} + +/// Retrieves the value associated with a given `key` from a Lua table `tbl`. +fn get_key<'a>(objects: &'a HashMap<&'a str, Json>, tbl: &Json, key: &str) -> Option<&'a str> { + let pairs = tbl["pairs"].as_array()?; + for kv in pairs.chunks_exact(2) { + #[rustfmt::skip] + let (Some(key_addr), Some(val_addr)) = (kv[0].as_str(), kv[1].as_str()) else { continue; }; + if objects[key_addr]["type"] == "string" && objects[key_addr]["data"].as_str() == Some(key) { + if objects[val_addr]["type"] == "string" { + return objects[val_addr]["data"].as_str(); + } else { + break; + } + } + } + None +} diff --git a/src/luau/json.rs b/src/luau/json.rs new file mode 100644 index 00000000..ce17a20e --- /dev/null +++ b/src/luau/json.rs @@ -0,0 +1,327 @@ +use std::array; +use std::collections::HashMap; +use std::iter::Peekable; +use std::ops::Index; +use std::str::CharIndices; + +// A simple JSON parser and representation. +// This parser supports only a subset of JSON specification and is intended for Luau's use cases. + +#[derive(Debug, PartialEq)] +pub(crate) enum Json<'a> { + Null, + Bool(bool), + Integer(i64), + Number(f64), + String(&'a str), + Array(Vec>), + Object(HashMap<&'a str, Json<'a>>), +} + +impl<'a> Index<&str> for Json<'a> { + type Output = Json<'a>; + + fn index(&self, key: &str) -> &Self::Output { + match self { + Json::Object(map) => map.get(key).unwrap_or(&Json::Null), + _ => &Json::Null, + } + } +} + +impl PartialEq<&str> for Json<'_> { + fn eq(&self, other: &&str) -> bool { + matches!(self, Json::String(s) if s == other) + } +} + +impl<'a> Json<'a> { + pub(crate) fn as_str(&self) -> Option<&'a str> { + match self { + Json::String(s) => Some(s), + _ => None, + } + } + + pub(crate) fn as_i64(&self) -> Option { + match self { + Json::Integer(i) => Some(*i), + Json::Number(n) if n.fract() == 0.0 => Some(*n as i64), + _ => None, + } + } + + pub(crate) fn as_u64(&self) -> Option { + self.as_i64() + .and_then(|i| if i >= 0 { Some(i as u64) } else { None }) + } + + pub(crate) fn as_array(&self) -> Option<&[Json<'a>]> { + match self { + Json::Array(arr) => Some(arr), + _ => None, + } + } + + pub(crate) fn as_object(&self) -> Option<&HashMap<&'a str, Json<'a>>> { + match self { + Json::Object(map) => Some(map), + _ => None, + } + } +} + +pub(crate) fn parse<'a>(s: &'a str) -> Result, &'static str> { + let s = s.trim_ascii(); + let mut chars = s.char_indices().peekable(); + let value = parse_value(s, &mut chars)?; + Ok(value) +} + +fn parse_value<'a>(s: &'a str, chars: &mut Peekable) -> Result, &'static str> { + skip_whitespace(chars); + match chars.peek() { + Some((_, '{')) => parse_object(s, chars), + Some((_, '[')) => parse_array(s, chars), + Some((_, '"')) => parse_string(s, chars).map(Json::String), + Some((_, 't' | 'f')) => parse_bool(chars), + Some((_, 'n')) => parse_null(chars), + Some((_, '-' | '0'..='9')) => parse_number(chars), + Some(_) => Err("unexpected character"), + None => Err("unexpected end of input"), + } +} + +fn parse_object<'a>(s: &'a str, chars: &mut Peekable) -> Result, &'static str> { + chars.next(); // consume '{' + + let mut map = HashMap::new(); + skip_whitespace(chars); + if matches!(chars.peek(), Some((_, '}'))) { + chars.next(); + return Ok(Json::Object(map)); + } + loop { + skip_whitespace(chars); + let key = parse_string(s, chars)?; + skip_whitespace(chars); + if !matches!(chars.next(), Some((_, ':'))) { + return Err("expected ':'"); + } + let value = parse_value(s, chars)?; + map.insert(key, value); + skip_whitespace(chars); + match chars.next() { + Some((_, ',')) => continue, + Some((_, '}')) => break, + _ => return Err("expected ',' or '}'"), + } + } + Ok(Json::Object(map)) +} + +fn parse_array<'a>(s: &'a str, chars: &mut Peekable) -> Result, &'static str> { + chars.next(); // consume '[' + + let mut arr = Vec::new(); + skip_whitespace(chars); + if matches!(chars.peek(), Some((_, ']'))) { + chars.next(); + return Ok(Json::Array(arr)); + } + loop { + skip_whitespace(chars); + arr.push(parse_value(s, chars)?); + skip_whitespace(chars); + match chars.next() { + Some((_, ',')) => continue, + Some((_, ']')) => return Ok(Json::Array(arr)), + _ => return Err("expected ',' or ']'"), + } + } +} + +fn parse_string<'a>(s: &'a str, chars: &mut Peekable) -> Result<&'a str, &'static str> { + if !matches!(chars.next(), Some((_, '"'))) { + return Err("expected string starting with '\"'"); + } + let start = chars.peek().map(|(i, _)| *i).unwrap_or(0); + for (i, c) in chars { + if c == '"' { + return Ok(&s[start..i]); + } + } + Err("unterminated string") +} + +fn parse_number(chars: &mut Peekable) -> Result, &'static str> { + let mut is_float = false; + let mut num = String::new(); + while let Some((_, c @ ('0'..='9' | '-' | '.' | 'e' | 'E' | '+'))) = chars.peek() { + num.push(*c); + is_float = is_float || matches!(c, '.' | 'e' | 'E'); + chars.next(); + } + if !is_float { + let i = num.parse::().map_err(|_| "invalid integer")?; + return Ok(Json::Integer(i)); + } + let n = num.parse::().map_err(|_| "invalid number")?; + Ok(Json::Number(n)) +} + +fn parse_bool(chars: &mut Peekable) -> Result, &'static str> { + let bool = next_chars(chars); + if bool == [Some('t'), Some('r'), Some('u'), Some('e')] { + return Ok(Json::Bool(true)); + } + if bool == [Some('f'), Some('a'), Some('l'), Some('s')] && matches!(chars.next(), Some((_, 'e'))) { + return Ok(Json::Bool(false)); + } + Err("invalid boolean literal") +} + +fn parse_null(chars: &mut Peekable) -> Result, &'static str> { + if next_chars(chars) == [Some('n'), Some('u'), Some('l'), Some('l')] { + return Ok(Json::Null); + } + Err("invalid \"null\" literal") +} + +fn skip_whitespace(chars: &mut Peekable) { + while let Some((_, ' ' | '\n' | '\r' | '\t')) = chars.peek() { + chars.next(); + } +} + +fn next_chars(chars: &mut Peekable) -> [Option; N] { + array::from_fn(|_| chars.next().map(|(_, c)| c)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse() { + assert_eq!(parse("null").unwrap(), Json::Null); + assert_eq!(parse("true").unwrap(), Json::Bool(true)); + assert_eq!(parse("false").unwrap(), Json::Bool(false)); + assert_eq!(parse("42").unwrap(), Json::Integer(42)); + assert_eq!(parse("42.0").unwrap(), Json::Number(42.0)); + assert_eq!(parse(r#""hello""#).unwrap(), Json::String("hello")); + assert_eq!( + parse("[1,2.0,3]").unwrap(), + Json::Array(vec![Json::Integer(1), Json::Number(2.0), Json::Integer(3)]) + ); + let mut obj = HashMap::new(); + obj.insert("key", Json::String("value")); + assert_eq!(parse(r#"{"key":"value"}"#).unwrap(), Json::Object(obj)); + } + + #[test] + fn test_whitespace_handling() { + assert_eq!(parse(" null ").unwrap(), Json::Null); + assert_eq!(parse(" true ").unwrap(), Json::Bool(true)); + assert_eq!( + parse(" [ 1 , 2.0 , 3 ] ").unwrap(), + Json::Array(vec![Json::Integer(1), Json::Number(2.0), Json::Integer(3)]) + ); + let mut obj = HashMap::new(); + obj.insert("key", Json::String("value")); + assert_eq!(parse(r#" { "key" : "value" } "#).unwrap(), Json::Object(obj)); + } + + #[test] + fn test_empty_collections() { + assert_eq!(parse("[]").unwrap(), Json::Array(vec![])); + assert_eq!(parse("{}").unwrap(), Json::Object(HashMap::new())); + assert_eq!(parse("[ ]").unwrap(), Json::Array(vec![])); + assert_eq!(parse("{ }").unwrap(), Json::Object(HashMap::new())); + } + + #[test] + fn test_nested_structures() { + assert_eq!( + parse(r#"{"nested":{"inner":"value"}}"#).unwrap(), + Json::Object({ + let mut outer = HashMap::new(); + let mut inner = HashMap::new(); + inner.insert("inner", Json::String("value")); + outer.insert("nested", Json::Object(inner)); + outer + }) + ); + assert_eq!( + parse("[[1,2],[3,4]]").unwrap(), + Json::Array(vec![ + Json::Array(vec![Json::Integer(1), Json::Integer(2)]), + Json::Array(vec![Json::Integer(3), Json::Integer(4)]) + ]) + ); + } + + #[test] + fn test_numbers() { + assert_eq!(parse("0").unwrap(), Json::Integer(0)); + assert_eq!(parse("-42").unwrap(), Json::Integer(-42)); + assert_eq!(parse("3.14").unwrap(), Json::Number(3.14)); + assert_eq!(parse("-3.14").unwrap(), Json::Number(-3.14)); + assert_eq!(parse("1e10").unwrap(), Json::Number(1e10)); + assert_eq!(parse("1E10").unwrap(), Json::Number(1E10)); + assert_eq!(parse("1e-10").unwrap(), Json::Number(1e-10)); + assert_eq!(parse("1.5e+10").unwrap(), Json::Number(1.5e+10)); + } + + #[test] + fn test_strings() { + assert_eq!(parse(r#""""#).unwrap(), Json::String("")); + assert_eq!(parse(r#""hello world""#).unwrap(), Json::String("hello world")); + assert_eq!( + parse(r#""with spaces and 123""#).unwrap(), + Json::String("with spaces and 123") + ); + } + + #[test] + fn test_mixed_array() { + assert_eq!( + parse(r#"[null, true, false, 35.1, 42, "text", [], {}]"#).unwrap(), + Json::Array(vec![ + Json::Null, + Json::Bool(true), + Json::Bool(false), + Json::Number(35.1), + Json::Integer(42), + Json::String("text"), + Json::Array(vec![]), + Json::Object(HashMap::new()) + ]) + ); + } + + #[test] + fn test_object_multiple_keys() { + let mut obj = HashMap::new(); + obj.insert("a", Json::Integer(1)); + obj.insert("b", Json::Bool(true)); + obj.insert("c", Json::Null); + assert_eq!(parse(r#"{"a":1,"b":true,"c":null}"#).unwrap(), Json::Object(obj)); + } + + #[test] + fn test_error_cases() { + assert!(parse("").is_err()); + assert!(parse("nul").is_err()); + assert!(parse("tru").is_err()); // typos:ignore + assert!(parse("fals").is_err()); // typos:ignore + assert!(parse(r#""unterminated"#).is_err()); + assert!(parse("[1,2,]").is_err()); + assert!(parse(r#"{"key""#).is_err()); + assert!(parse(r#"{"key":"value""#).is_err()); + assert!(parse(r#"{"key":"value",}"#).is_err()); + assert!(parse("invalid").is_err()); + assert!(parse("[1 2]").is_err()); + assert!(parse(r#"{"key":"value" "key2":"value2"}"#).is_err()); + } +} diff --git a/src/luau/mod.rs b/src/luau/mod.rs new file mode 100644 index 00000000..5bb54069 --- /dev/null +++ b/src/luau/mod.rs @@ -0,0 +1,159 @@ +//! Luau-specific extensions and types. +//! +//! This module provides Luau-specific functionality including custom [`require`] implementations, +//! heap memory analysis, and Luau VM integration utilities. +//! +//! [`require`]: crate::Lua::create_require_function + +use std::ffi::{CStr, CString}; +use std::os::raw::c_int; +use std::ptr; + +use crate::chunk::ChunkMode; +use crate::error::{Error, Result}; +use crate::function::Function; +use crate::state::{ExtraData, Lua, callback_error_ext}; +use crate::traits::{FromLuaMulti, IntoLua}; +use crate::types::MaybeSend; + +pub use heap_dump::HeapDump; +pub use require::{FsRequirer, NavigateError, Require}; + +// Since Luau has some missing standard functions, we re-implement them here + +impl Lua { + /// Create a custom Luau `require` function using provided [`Require`] implementation to find + /// and load modules. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn create_require_function(&self, require: R) -> Result { + require::create_require_function(self, require) + } + + /// Set the memory category for subsequent allocations from this Lua state. + /// + /// The category "main" is reserved for the default memory category. + /// Maximum of 255 categories can be registered. + /// The category is set per Lua thread (state) and affects all allocations made from that + /// thread. + /// + /// Return error if too many categories are registered or if the category name is invalid. + /// + /// See [`Lua::heap_dump`] for tracking memory usage by category. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_memory_category(&self, category: &str) -> Result<()> { + let lua = self.lock(); + + if category.contains(|c| !matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_')) { + return Err(Error::runtime("invalid memory category name")); + } + let cat_id = unsafe { + let extra = ExtraData::get(lua.state()); + match ((*extra).mem_categories.iter().enumerate()) + .find(|&(_, name)| name.as_bytes() == category.as_bytes()) + { + Some((id, _)) => id as u8, + None => { + let new_id = (*extra).mem_categories.len() as u8; + if new_id == 255 { + return Err(Error::runtime("too many memory categories registered")); + } + (*extra).mem_categories.push(CString::new(category).unwrap()); + new_id + } + } + }; + unsafe { ffi::lua_setmemcat(lua.state(), cat_id as i32) }; + + Ok(()) + } + + /// Dumps the current Lua VM heap state. + /// + /// The returned `HeapDump` can be used to analyze memory usage. + /// It's recommended to call [`Lua::gc_collect`] before dumping the heap. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn heap_dump(&self) -> Result { + let lua = self.lock(); + unsafe { heap_dump::HeapDump::new(lua.state()).ok_or_else(|| Error::runtime("unable to dump heap")) } + } + + pub(crate) unsafe fn configure_luau(&self) -> Result<()> { + let globals = self.globals(); + + globals.raw_set("collectgarbage", self.create_c_function(lua_collectgarbage)?)?; + globals.raw_set("loadstring", self.create_c_function(lua_loadstring)?)?; + + // Set `_VERSION` global to include version number + // The environment variable `LUAU_VERSION` set by the build script + if let Some(version) = ffi::luau_version() { + globals.raw_set("_VERSION", format!("Luau {version}"))?; + } + + // Enable default `require` implementation + let require = self.create_require_function(FsRequirer::new())?; + self.globals().raw_set("require", require)?; + + Ok(()) + } +} + +unsafe extern "C-unwind" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int { + let option = ffi::luaL_optstring(state, 1, cstr!("collect")); + let option = CStr::from_ptr(option); + let arg = ffi::luaL_optinteger(state, 2, 0); + let is_sandboxed = (*ExtraData::get(state)).sandboxed; + match option.to_str() { + Ok("collect") if !is_sandboxed => { + ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); + 0 + } + Ok("stop") if !is_sandboxed => { + ffi::lua_gc(state, ffi::LUA_GCSTOP, 0); + 0 + } + Ok("restart") if !is_sandboxed => { + ffi::lua_gc(state, ffi::LUA_GCRESTART, 0); + 0 + } + Ok("count") => { + let kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0) as ffi::lua_Number; + let kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0) as ffi::lua_Number; + ffi::lua_pushnumber(state, kbytes + kbytes_rem / 1024.0); + 1 + } + Ok("step") if !is_sandboxed => { + let res = ffi::lua_gc(state, ffi::LUA_GCSTEP, arg as _); + ffi::lua_pushboolean(state, res); + 1 + } + Ok("isrunning") if !is_sandboxed => { + let res = ffi::lua_gc(state, ffi::LUA_GCISRUNNING, 0); + ffi::lua_pushboolean(state, res); + 1 + } + _ => ffi::luaL_error(state, cstr!("collectgarbage called with invalid option")), + } +} + +unsafe extern "C-unwind" fn lua_loadstring(state: *mut ffi::lua_State) -> c_int { + callback_error_ext(state, ptr::null_mut(), false, move |extra, nargs| { + let rawlua = (*extra).raw_lua(); + let (chunk, chunk_name) = + <(String, Option)>::from_stack_args(nargs, 1, Some("loadstring"), rawlua)?; + let chunk_name = chunk_name.as_deref().unwrap_or("=(loadstring)"); + (rawlua.lua()) + .load(chunk) + .set_name(chunk_name) + .set_mode(ChunkMode::Text) + .into_function()? + .push_into_stack(rawlua)?; + Ok(1) + }) +} + +mod heap_dump; +mod json; +mod require; diff --git a/src/luau/require.rs b/src/luau/require.rs new file mode 100644 index 00000000..3aee6f27 --- /dev/null +++ b/src/luau/require.rs @@ -0,0 +1,470 @@ +use std::cell::RefCell; +use std::ffi::CStr; +use std::io::Result as IoResult; +use std::ops::{Deref, DerefMut}; +use std::os::raw::{c_char, c_int, c_void}; +use std::result::Result as StdResult; +use std::{fmt, mem, ptr}; + +use crate::error::{Error, Result}; +use crate::function::Function; +use crate::state::{Lua, callback_error_ext}; +use crate::table::Table; +use crate::types::MaybeSend; + +pub use fs::FsRequirer; + +/// An error that can occur during navigation in the Luau `require-by-string` system. +#[derive(Debug, Clone)] +pub enum NavigateError { + Ambiguous, + NotFound, + Other(Error), +} + +#[cfg(feature = "luau")] +trait IntoNavigateResult { + fn into_nav_result(self) -> Result; +} + +#[cfg(feature = "luau")] +impl IntoNavigateResult for StdResult<(), NavigateError> { + fn into_nav_result(self) -> Result { + match self { + Ok(()) => Ok(ffi::luarequire_NavigateResult::Success), + Err(NavigateError::Ambiguous) => Ok(ffi::luarequire_NavigateResult::Ambiguous), + Err(NavigateError::NotFound) => Ok(ffi::luarequire_NavigateResult::NotFound), + Err(NavigateError::Other(err)) => Err(err), + } + } +} + +impl From for NavigateError { + fn from(err: Error) -> Self { + NavigateError::Other(err) + } +} + +#[cfg(feature = "luau")] +type WriteResult = ffi::luarequire_WriteResult; + +#[cfg(feature = "luau")] +type ConfigStatus = ffi::luarequire_ConfigStatus; + +/// A trait for handling modules loading and navigation in the Luau `require-by-string` system. +pub trait Require { + /// Returns `true` if "require" is permitted for the given chunk name. + fn is_require_allowed(&self, chunk_name: &str) -> bool; + + /// Resets the internal state to point at the requirer module. + fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>; + + /// Resets the internal state to point at an aliased module. + /// + /// This function received an exact path from a configuration file. + /// It's only called when an alias's path cannot be resolved relative to its + /// configuration file. + fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>; + + // Navigate to parent directory + fn to_parent(&mut self) -> StdResult<(), NavigateError>; + + /// Navigate to the given child directory. + fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>; + + /// Returns whether the context is currently pointing at a module. + fn has_module(&self) -> bool; + + /// Provides a cache key representing the current module. + /// + /// This function is only called if `has_module` returns true. + fn cache_key(&self) -> String; + + /// Returns whether a configuration is present in the current context. + fn has_config(&self) -> bool; + + /// Returns the contents of the configuration file in the current context. + /// + /// This function is only called if `has_config` returns true. + fn config(&self) -> IoResult>; + + /// Returns a loader function for the current module, that when called, loads the module + /// and returns the result. + /// + /// Loader can be sync or async. + /// This function is only called if `has_module` returns true. + fn loader(&self, lua: &Lua) -> Result; +} + +impl fmt::Debug for dyn Require { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "") + } +} + +struct Context { + require: Box, + config_cache: Option>>, +} + +impl Deref for Context { + type Target = dyn Require; + + fn deref(&self) -> &Self::Target { + &*self.require + } +} + +impl DerefMut for Context { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.require + } +} + +impl Context { + fn new(require: impl Require + MaybeSend + 'static) -> Self { + Context { + require: Box::new(require), + config_cache: None, + } + } +} + +macro_rules! try_borrow { + ($state:expr, $ctx:expr) => { + match (*($ctx as *const RefCell)).try_borrow() { + Ok(ctx) => ctx, + Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")), + } + }; +} + +macro_rules! try_borrow_mut { + ($state:expr, $ctx:expr) => { + match (*($ctx as *const RefCell)).try_borrow_mut() { + Ok(ctx) => ctx, + Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")), + } + }; +} + +#[cfg(feature = "luau")] +pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) { + if config.is_null() { + return; + } + + unsafe extern "C-unwind" fn is_require_allowed( + state: *mut ffi::lua_State, + ctx: *mut c_void, + requirer_chunkname: *const c_char, + ) -> bool { + if requirer_chunkname.is_null() { + return false; + } + + let this = try_borrow!(state, ctx); + let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy(); + this.is_require_allowed(&chunk_name) + } + + unsafe extern "C-unwind" fn reset( + state: *mut ffi::lua_State, + ctx: *mut c_void, + requirer_chunkname: *const c_char, + ) -> ffi::luarequire_NavigateResult { + let mut this = try_borrow_mut!(state, ctx); + let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy(); + callback_error_ext(state, ptr::null_mut(), true, move |_, _| { + this.reset(&chunk_name).into_nav_result() + }) + } + + unsafe extern "C-unwind" fn jump_to_alias( + state: *mut ffi::lua_State, + ctx: *mut c_void, + path: *const c_char, + ) -> ffi::luarequire_NavigateResult { + let mut this = try_borrow_mut!(state, ctx); + let path = CStr::from_ptr(path).to_string_lossy(); + callback_error_ext(state, ptr::null_mut(), true, move |_, _| { + this.jump_to_alias(&path).into_nav_result() + }) + } + + unsafe extern "C-unwind" fn to_parent( + state: *mut ffi::lua_State, + ctx: *mut c_void, + ) -> ffi::luarequire_NavigateResult { + let mut this = try_borrow_mut!(state, ctx); + callback_error_ext(state, ptr::null_mut(), true, move |_, _| { + this.to_parent().into_nav_result() + }) + } + + unsafe extern "C-unwind" fn to_child( + state: *mut ffi::lua_State, + ctx: *mut c_void, + name: *const c_char, + ) -> ffi::luarequire_NavigateResult { + let mut this = try_borrow_mut!(state, ctx); + let name = CStr::from_ptr(name).to_string_lossy(); + callback_error_ext(state, ptr::null_mut(), true, move |_, _| { + this.to_child(&name).into_nav_result() + }) + } + + unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool { + let this = try_borrow!(state, ctx); + this.has_module() + } + + unsafe extern "C-unwind" fn get_chunkname( + _state: *mut ffi::lua_State, + _ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> WriteResult { + write_to_buffer(buffer, buffer_size, size_out, &[]) + } + + unsafe extern "C-unwind" fn get_loadname( + _state: *mut ffi::lua_State, + _ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> WriteResult { + write_to_buffer(buffer, buffer_size, size_out, &[]) + } + + unsafe extern "C-unwind" fn get_cache_key( + state: *mut ffi::lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> WriteResult { + let this = try_borrow!(state, ctx); + let cache_key = this.cache_key(); + write_to_buffer(buffer, buffer_size, size_out, cache_key.as_bytes()) + } + + unsafe extern "C-unwind" fn get_config_status( + state: *mut ffi::lua_State, + ctx: *mut c_void, + ) -> ConfigStatus { + let mut this = try_borrow_mut!(state, ctx); + if this.has_config() { + this.config_cache = Some(this.config()); + if let Some(Ok(data)) = &this.config_cache { + return detect_config_format(data); + } + } + ConfigStatus::Absent + } + + unsafe extern "C-unwind" fn get_config( + state: *mut ffi::lua_State, + ctx: *mut c_void, + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + ) -> WriteResult { + let mut this = try_borrow_mut!(state, ctx); + let config = callback_error_ext(state, ptr::null_mut(), true, move |_, _| { + Ok(this.config_cache.take().unwrap_or_else(|| this.config())?) + }); + write_to_buffer(buffer, buffer_size, size_out, &config) + } + + unsafe extern "C-unwind" fn load( + state: *mut ffi::lua_State, + ctx: *mut c_void, + _path: *const c_char, + _chunkname: *const c_char, + _loadname: *const c_char, + ) -> c_int { + let this = try_borrow!(state, ctx); + callback_error_ext(state, ptr::null_mut(), true, move |extra, _| { + let rawlua = (*extra).raw_lua(); + let loader = this.loader(rawlua.lua())?; + rawlua.push(loader)?; + Ok(1) + }) + } + + (*config).is_require_allowed = is_require_allowed; + (*config).reset = reset; + (*config).jump_to_alias = jump_to_alias; + (*config).to_alias_override = None; + (*config).to_alias_fallback = None; + (*config).to_parent = to_parent; + (*config).to_child = to_child; + (*config).is_module_present = is_module_present; + (*config).get_chunkname = get_chunkname; + (*config).get_loadname = get_loadname; + (*config).get_cache_key = get_cache_key; + (*config).get_config_status = get_config_status; + (*config).get_alias = None; + (*config).get_config = Some(get_config); + (*config).load = load; +} + +/// Detect configuration file format (JSON or Luau) +#[cfg(feature = "luau")] +fn detect_config_format(data: &[u8]) -> ConfigStatus { + let data = data.trim_ascii(); + if data.starts_with(b"{") { + let data = &data[1..].trim_ascii_start(); + if data.starts_with(b"\"") || data == b"}" { + return ConfigStatus::PresentJson; + } + } + ConfigStatus::PresentLuau +} + +/// Helper function to write data to a buffer +#[cfg(feature = "luau")] +unsafe fn write_to_buffer( + buffer: *mut c_char, + buffer_size: usize, + size_out: *mut usize, + data: &[u8], +) -> WriteResult { + // the buffer must be null terminated as it's a c++ `std::string` data() buffer + let is_null_terminated = data.last() == Some(&0); + *size_out = data.len() + if is_null_terminated { 0 } else { 1 }; + if *size_out > buffer_size { + return WriteResult::BufferTooSmall; + } + ptr::copy_nonoverlapping(data.as_ptr(), buffer as *mut _, data.len()); + if !is_null_terminated { + *buffer.add(data.len()) = 0; + } + WriteResult::Success +} + +#[cfg(feature = "luau")] +pub(super) fn create_require_function( + lua: &Lua, + require: R, +) -> Result { + unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int { + let mut ar: ffi::lua_Debug = mem::zeroed(); + for level in 2.. { + if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 { + ffi::luaL_error(state, cstr!("require is not supported in this context")); + } + if CStr::from_ptr(ar.what) != c"C" { + break; + } + } + ffi::lua_pushstring(state, ar.source); + 1 + } + + unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int { + let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1)); + let ctx = try_borrow!(state, ctx); + let cache_key = ctx.cache_key(); + ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len()); + 1 + } + + let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe { + lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| { + let context = Context::new(require); + let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context)); + ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1); + ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file")); + ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE); + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE")); + }) + }?; + + unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_where(state, 1); + ffi::lua_pushvalue(state, 1); + ffi::lua_concat(state, 2); + ffi::lua_error(state); + } + + unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1))); + 1 + } + + unsafe extern "C-unwind" fn to_lowercase(state: *mut ffi::lua_State) -> c_int { + let s = ffi::luaL_checkstring(state, 1); + let s = CStr::from_ptr(s); + if !s.to_bytes().iter().any(|&c| c.is_ascii_uppercase()) { + // If the string does not contain any uppercase ASCII letters, return it as is + return 1; + } + callback_error_ext(state, ptr::null_mut(), true, |extra, _| { + let s = (s.to_bytes().iter()) + .map(|&c| c.to_ascii_lowercase()) + .collect::(); + (*extra).raw_lua().push(s).map(|_| 1) + }) + } + + let (error, r#type, to_lowercase) = unsafe { + lua.exec_raw::<(Function, Function, Function)>((), move |state| { + ffi::lua_pushcfunctiond(state, error, cstr!("error")); + ffi::lua_pushcfunctiond(state, r#type, cstr!("type")); + ffi::lua_pushcfunctiond(state, to_lowercase, cstr!("to_lowercase")); + }) + }?; + + // Prepare environment for the "require" function + let env = lua.create_table_with_capacity(0, 7)?; + env.raw_set("get_cache_key", get_cache_key)?; + env.raw_set("find_current_file", find_current_file)?; + env.raw_set("proxyrequire", proxyrequire)?; + env.raw_set("REGISTERED_MODULES", registered_modules)?; + env.raw_set("LOADER_CACHE", loader_cache)?; + env.raw_set("error", error)?; + env.raw_set("type", r#type)?; + env.raw_set("to_lowercase", to_lowercase)?; + + lua.load( + r#" + local path = ... + if type(path) ~= "string" then + error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")") + end + + -- Check if the module (path) is explicitly registered + local maybe_result = REGISTERED_MODULES[to_lowercase(path)] + if maybe_result ~= nil then + return maybe_result + end + + local loader = proxyrequire(path, find_current_file()) + local cache_key = get_cache_key() + -- Check if the loader result is already cached + local result = LOADER_CACHE[cache_key] + if result ~= nil then + return result + end + + -- Call the loader function and cache the result + result = loader() + if result == nil then + result = true + end + LOADER_CACHE[cache_key] = result + return result + "#, + ) + .try_cache() + .set_name("=__mlua_require") + .set_environment(env) + .into_function() +} + +mod fs; diff --git a/src/luau/require/fs.rs b/src/luau/require/fs.rs new file mode 100644 index 00000000..f6373434 --- /dev/null +++ b/src/luau/require/fs.rs @@ -0,0 +1,278 @@ +use std::collections::VecDeque; +use std::io::Result as IoResult; +use std::path::{Component, Path, PathBuf}; +use std::result::Result as StdResult; +use std::{env, fs}; + +use crate::error::Result; +use crate::function::Function; +use crate::state::Lua; + +use super::{NavigateError, Require}; + +/// The standard implementation of Luau `require-by-string` navigation. +#[derive(Default, Debug)] +pub struct FsRequirer { + /// An absolute path to the current Luau module (not mapped to a physical file) + abs_path: PathBuf, + /// A relative path to the current Luau module (not mapped to a physical file) + rel_path: PathBuf, + /// A physical path to the current Luau module, which is a file or a directory with an + /// `init.lua(u)` file + resolved_path: Option, +} + +impl FsRequirer { + /// The prefix used for chunk names in the require system. + /// Only chunk names starting with this prefix are allowed to be used in `require`. + const CHUNK_PREFIX: &str = "@"; + + /// The file extensions that are considered valid for Luau modules. + const FILE_EXTENSIONS: &[&str] = &["luau", "lua"]; + + /// The filename for the JSON configuration file. + const LUAURC_CONFIG_FILENAME: &str = ".luaurc"; + + /// The filename for the Luau configuration file. + const LUAU_CONFIG_FILENAME: &str = ".config.luau"; + + /// Creates a new `FsRequirer` instance. + pub fn new() -> Self { + Self::default() + } + + fn normalize_chunk_name(chunk_name: &str) -> &str { + if let Some((path, line)) = chunk_name.rsplit_once(':') + && line.parse::().is_ok() + { + return path; + } + chunk_name + } + + // Normalizes the path by removing unnecessary components + fn normalize_path(path: &Path) -> PathBuf { + let mut components = VecDeque::new(); + + for comp in path.components() { + match comp { + Component::Prefix(..) | Component::RootDir => { + components.push_back(comp); + } + Component::CurDir => {} + Component::ParentDir => { + if matches!(components.back(), None | Some(Component::ParentDir)) { + components.push_back(Component::ParentDir); + } else if matches!(components.back(), Some(Component::Normal(..))) { + components.pop_back(); + } + } + Component::Normal(..) => components.push_back(comp), + } + } + + if matches!(components.front(), None | Some(Component::Normal(..))) { + components.push_front(Component::CurDir); + } + + // Join the components back together + components.into_iter().collect() + } + + /// Resolve a Luau module path to a physical file or directory. + /// + /// Empty directories without init files are considered valid as "intermediate" directories. + fn resolve_module(path: &Path) -> StdResult, NavigateError> { + let mut found_path = None; + + if path.components().next_back() != Some(Component::Normal("init".as_ref())) { + let current_ext = (path.extension().and_then(|s| s.to_str())) + .map(|s| format!("{s}.")) + .unwrap_or_default(); + for ext in Self::FILE_EXTENSIONS { + let candidate = path.with_extension(format!("{current_ext}{ext}")); + if candidate.is_file() && found_path.replace(candidate).is_some() { + return Err(NavigateError::Ambiguous); + } + } + } + if path.is_dir() { + for component in Self::FILE_EXTENSIONS.iter().map(|ext| format!("init.{ext}")) { + let candidate = path.join(component); + if candidate.is_file() && found_path.replace(candidate).is_some() { + return Err(NavigateError::Ambiguous); + } + } + + if found_path.is_none() { + // Directories without init files are considered valid "intermediate" path + return Ok(None); + } + } + + Ok(Some(found_path.ok_or(NavigateError::NotFound)?)) + } +} + +impl Require for FsRequirer { + fn is_require_allowed(&self, chunk_name: &str) -> bool { + chunk_name.starts_with(Self::CHUNK_PREFIX) + } + + fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError> { + if !chunk_name.starts_with(Self::CHUNK_PREFIX) { + return Err(NavigateError::NotFound); + } + let chunk_name = Self::normalize_chunk_name(&chunk_name[1..]); + let chunk_path = Self::normalize_path(chunk_name.as_ref()); + + if chunk_path.extension() == Some("rs".as_ref()) { + // Special case for Rust source files, reset to the current directory + let chunk_filename = chunk_path.file_name().unwrap(); + let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?; + self.abs_path = Self::normalize_path(&cwd.join(chunk_filename)); + self.rel_path = ([Component::CurDir, Component::Normal(chunk_filename)].into_iter()).collect(); + self.resolved_path = None; + + return Ok(()); + } + + if chunk_path.is_absolute() { + let resolved_path = Self::resolve_module(&chunk_path)?; + self.abs_path = chunk_path.clone(); + self.rel_path = chunk_path; + self.resolved_path = resolved_path; + } else { + // Relative path + let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?; + let abs_path = Self::normalize_path(&cwd.join(&chunk_path)); + let resolved_path = Self::resolve_module(&abs_path)?; + self.abs_path = abs_path; + self.rel_path = chunk_path; + self.resolved_path = resolved_path; + } + + Ok(()) + } + + fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> { + let path = Self::normalize_path(path.as_ref()); + let resolved_path = Self::resolve_module(&path)?; + + self.abs_path = path.clone(); + self.rel_path = path; + self.resolved_path = resolved_path; + + Ok(()) + } + + fn to_parent(&mut self) -> StdResult<(), NavigateError> { + let mut abs_path = self.abs_path.clone(); + if !abs_path.pop() { + // It's important to return `NotFound` if we reached the root, as it's a "recoverable" error if we + // cannot go beyond the root directory. + // Luau "require-by-string` has a special logic to search for config file to resolve aliases. + return Err(NavigateError::NotFound); + } + let mut rel_parent = self.rel_path.clone(); + rel_parent.pop(); + let resolved_path = Self::resolve_module(&abs_path)?; + + self.abs_path = abs_path; + self.rel_path = Self::normalize_path(&rel_parent); + self.resolved_path = resolved_path; + + Ok(()) + } + + fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> { + let abs_path = self.abs_path.join(name); + let rel_path = self.rel_path.join(name); + let resolved_path = Self::resolve_module(&abs_path)?; + + self.abs_path = abs_path; + self.rel_path = rel_path; + self.resolved_path = resolved_path; + + Ok(()) + } + + fn has_module(&self) -> bool { + (self.resolved_path.as_deref()) + .map(Path::is_file) + .unwrap_or(false) + } + + fn cache_key(&self) -> String { + self.resolved_path.as_deref().unwrap().display().to_string() + } + + fn has_config(&self) -> bool { + self.abs_path.is_dir() && self.abs_path.join(Self::LUAURC_CONFIG_FILENAME).is_file() + || self.abs_path.is_dir() && self.abs_path.join(Self::LUAU_CONFIG_FILENAME).is_file() + } + + fn config(&self) -> IoResult> { + if self.abs_path.join(Self::LUAURC_CONFIG_FILENAME).is_file() { + return fs::read(self.abs_path.join(Self::LUAURC_CONFIG_FILENAME)); + } + fs::read(self.abs_path.join(Self::LUAU_CONFIG_FILENAME)) + } + + fn loader(&self, lua: &Lua) -> Result { + let name = format!("@{}", self.rel_path.display()); + lua.load(self.resolved_path.as_deref().unwrap()) + .set_name(name) + .into_function() + } +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use super::FsRequirer; + + #[test] + fn test_path_normalize() { + for (input, expected) in [ + // Basic formatting checks + ("", "./"), + (".", "./"), + ("a/relative/path", "./a/relative/path"), + // Paths containing extraneous '.' and '/' symbols + ("./remove/extraneous/symbols/", "./remove/extraneous/symbols"), + ("./remove/extraneous//symbols", "./remove/extraneous/symbols"), + ("./remove/extraneous/symbols/.", "./remove/extraneous/symbols"), + ("./remove/extraneous/./symbols", "./remove/extraneous/symbols"), + ("../remove/extraneous/symbols/", "../remove/extraneous/symbols"), + ("../remove/extraneous//symbols", "../remove/extraneous/symbols"), + ("../remove/extraneous/symbols/.", "../remove/extraneous/symbols"), + ("../remove/extraneous/./symbols", "../remove/extraneous/symbols"), + ("/remove/extraneous/symbols/", "/remove/extraneous/symbols"), + ("/remove/extraneous//symbols", "/remove/extraneous/symbols"), + ("/remove/extraneous/symbols/.", "/remove/extraneous/symbols"), + ("/remove/extraneous/./symbols", "/remove/extraneous/symbols"), + // Paths containing '..' + ("./remove/me/..", "./remove"), + ("./remove/me/../", "./remove"), + ("../remove/me/..", "../remove"), + ("../remove/me/../", "../remove"), + ("/remove/me/..", "/remove"), + ("/remove/me/../", "/remove"), + ("./..", "../"), + ("./../", "../"), + ("../..", "../../"), + ("../../", "../../"), + // '..' disappears if path is absolute and component is non-erasable + ("/../", "/"), + ] { + let path = FsRequirer::normalize_path(input.as_ref()); + assert_eq!( + &path, + expected.as_ref() as &Path, + "wrong normalization for {input}" + ); + } + } +} diff --git a/src/macros.rs b/src/macros.rs index eb49c157..5c487efb 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -10,8 +10,7 @@ macro_rules! bug_msg { macro_rules! cstr { ($s:expr) => { - concat!($s, "\0") as *const str as *const [::std::os::raw::c_char] - as *const ::std::os::raw::c_char + concat!($s, "\0") as *const str as *const [::std::os::raw::c_char] as *const ::std::os::raw::c_char }; } @@ -101,9 +100,15 @@ macro_rules! protect_lua { }; ($state:expr, $nargs:expr, $nresults:expr, fn($state_inner:ident) $code:expr) => {{ - unsafe extern "C" fn do_call($state_inner: *mut ffi::lua_State) -> ::std::os::raw::c_int { + use ::std::os::raw::c_int; + unsafe extern "C-unwind" fn do_call($state_inner: *mut ffi::lua_State) -> c_int { $code; - $nresults + let nresults = $nresults; + if nresults == ::ffi::LUA_MULTRET { + ffi::lua_gettop($state_inner) + } else { + nresults + } } crate::util::protect_lua_call($state, $nargs, do_call) diff --git a/src/memory.rs b/src/memory.rs new file mode 100644 index 00000000..a484e277 --- /dev/null +++ b/src/memory.rs @@ -0,0 +1,151 @@ +use std::alloc::{self, Layout}; +use std::os::raw::c_void; +use std::ptr; + +pub(crate) static ALLOCATOR: ffi::lua_Alloc = allocator; + +#[repr(C)] +#[derive(Default)] +pub(crate) struct MemoryState { + used_memory: isize, + memory_limit: isize, + // Can be set to temporary ignore the memory limit. + // This is used when calling `lua_pushcfunction` for lua5.1/jit/luau. + ignore_limit: bool, + // Indicates that the memory limit was reached on the last allocation. + #[cfg(feature = "luau")] + limit_reached: bool, +} + +impl MemoryState { + #[cfg(feature = "luau")] + #[inline] + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + let mut mem_state = ptr::null_mut(); + ffi::lua_getallocf(state, &mut mem_state); + mlua_assert!(!mem_state.is_null(), "Luau state has no allocator userdata"); + mem_state as *mut MemoryState + } + + #[cfg(not(feature = "luau"))] + #[inline] + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + let mut mem_state = ptr::null_mut(); + if !ptr::fn_addr_eq(ffi::lua_getallocf(state, &mut mem_state), ALLOCATOR) { + mem_state = ptr::null_mut(); + } + mem_state as *mut MemoryState + } + + #[inline] + pub(crate) fn used_memory(&self) -> usize { + self.used_memory as usize + } + + #[inline] + pub(crate) fn memory_limit(&self) -> usize { + self.memory_limit as usize + } + + #[inline] + pub(crate) fn set_memory_limit(&mut self, limit: usize) -> usize { + let prev_limit = self.memory_limit; + self.memory_limit = limit as isize; + prev_limit as usize + } + + // This function is used primarily for calling `lua_pushcfunction` in lua5.1/jit/luau + // to bypass the memory limit (if set). + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + #[inline] + pub(crate) unsafe fn relax_limit_with(state: *mut ffi::lua_State, f: impl FnOnce()) { + let mem_state = Self::get(state); + if !mem_state.is_null() { + (*mem_state).ignore_limit = true; + f(); + (*mem_state).ignore_limit = false; + } else { + f(); + } + } + + // Does nothing apart from calling `f()`, we don't need to bypass any limits + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + #[inline] + pub(crate) unsafe fn relax_limit_with(_state: *mut ffi::lua_State, f: impl FnOnce()) { + f(); + } + + // Returns `true` if the memory limit was reached on the last memory operation + #[cfg(feature = "luau")] + #[inline] + pub(crate) unsafe fn limit_reached(state: *mut ffi::lua_State) -> bool { + (*Self::get(state)).limit_reached + } +} + +unsafe extern "C" fn allocator( + extra: *mut c_void, + ptr: *mut c_void, + osize: usize, + nsize: usize, +) -> *mut c_void { + let mem_state = &mut *(extra as *mut MemoryState); + #[cfg(feature = "luau")] + { + // Reset the flag + mem_state.limit_reached = false; + } + + if nsize == 0 { + // Free memory + if !ptr.is_null() { + let layout = Layout::from_size_align_unchecked(osize, ffi::SYS_MIN_ALIGN); + alloc::dealloc(ptr as *mut u8, layout); + mem_state.used_memory -= osize as isize; + } + return ptr::null_mut(); + } + + // Do not allocate more than isize::MAX + if nsize > isize::MAX as usize { + return ptr::null_mut(); + } + + // Are we fit to the memory limits? + let mut mem_diff = nsize as isize; + if !ptr.is_null() { + mem_diff -= osize as isize; + } + let mem_limit = mem_state.memory_limit; + let new_used_memory = mem_state.used_memory + mem_diff; + if mem_limit > 0 && new_used_memory > mem_limit && !mem_state.ignore_limit { + #[cfg(feature = "luau")] + { + mem_state.limit_reached = true; + } + return ptr::null_mut(); + } + mem_state.used_memory += mem_diff; + + if ptr.is_null() { + // Allocate new memory + let new_layout = match Layout::from_size_align(nsize, ffi::SYS_MIN_ALIGN) { + Ok(layout) => layout, + Err(_) => return ptr::null_mut(), + }; + let new_ptr = alloc::alloc(new_layout) as *mut c_void; + if new_ptr.is_null() { + alloc::handle_alloc_error(new_layout); + } + return new_ptr; + } + + // Reallocate memory + let old_layout = Layout::from_size_align_unchecked(osize, ffi::SYS_MIN_ALIGN); + let new_ptr = alloc::realloc(ptr as *mut u8, old_layout, nsize) as *mut c_void; + if new_ptr.is_null() { + alloc::handle_alloc_error(old_layout); + } + new_ptr +} diff --git a/src/multi.rs b/src/multi.rs index 36ad19cb..f82f4cb5 100644 --- a/src/multi.rs +++ b/src/multi.rs @@ -1,53 +1,229 @@ -#![allow(clippy::wrong_self_convention)] - +use std::collections::{VecDeque, vec_deque}; use std::iter::FromIterator; +use std::mem; use std::ops::{Deref, DerefMut}; +use std::os::raw::c_int; use std::result::Result as StdResult; use crate::error::Result; -use crate::lua::Lua; -use crate::value::{FromLua, FromLuaMulti, MultiValue, Nil, ToLua, ToLuaMulti}; +use crate::state::{Lua, RawLua}; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::util::check_stack; +use crate::value::{Nil, Value}; -/// Result is convertible to `MultiValue` following the common Lua idiom of returning the result +/// Result is convertible to [`MultiValue`] following the common Lua idiom of returning the result /// on success, or in the case of an error, returning `nil` and an error message. -impl<'lua, T: ToLua<'lua>, E: ToLua<'lua>> ToLuaMulti<'lua> for StdResult { - fn to_lua_multi(self, lua: &'lua Lua) -> Result> { - let mut result = MultiValue::new_or_cached(lua); +impl IntoLuaMulti for StdResult { + #[inline] + fn into_lua_multi(self, lua: &Lua) -> Result { match self { - Ok(v) => result.push_front(v.to_lua(lua)?), - Err(e) => { - result.push_front(e.to_lua(lua)?); - result.push_front(Nil); - } + Ok(val) => (val,).into_lua_multi(lua), + Err(err) => (Nil, err).into_lua_multi(lua), + } + } + + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + match self { + Ok(val) => (val,).push_into_stack_multi(lua), + Err(err) => (Nil, err).push_into_stack_multi(lua), + } + } +} + +impl IntoLuaMulti for StdResult<(), E> { + #[inline] + fn into_lua_multi(self, lua: &Lua) -> Result { + match self { + Ok(_) => const { Ok(MultiValue::new()) }, + Err(err) => (Nil, err).into_lua_multi(lua), + } + } + + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + match self { + Ok(_) => Ok(0), + Err(err) => (Nil, err).push_into_stack_multi(lua), } - Ok(result) } } -impl<'lua, T: ToLua<'lua>> ToLuaMulti<'lua> for T { - fn to_lua_multi(self, lua: &'lua Lua) -> Result> { - let mut v = MultiValue::new_or_cached(lua); - v.push_front(self.to_lua(lua)?); +impl IntoLuaMulti for T { + #[inline] + fn into_lua_multi(self, lua: &Lua) -> Result { + let mut v = MultiValue::with_capacity(1); + v.push_back(self.into_lua(lua)?); Ok(v) } + + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + self.push_into_stack(lua)?; + Ok(1) + } +} + +impl FromLuaMulti for T { + #[inline] + fn from_lua_multi(mut values: MultiValue, lua: &Lua) -> Result { + T::from_lua(values.pop_front().unwrap_or(Nil), lua) + } + + #[inline] + fn from_lua_args(mut args: MultiValue, i: usize, to: Option<&str>, lua: &Lua) -> Result { + T::from_lua_arg(args.pop_front().unwrap_or(Nil), i, to, lua) + } + + #[inline] + unsafe fn from_stack_multi(nvals: c_int, lua: &RawLua) -> Result { + if nvals == 0 { + return T::from_lua(Nil, lua.lua()); + } + T::from_stack(-nvals, lua) + } + + #[inline] + unsafe fn from_stack_args(nargs: c_int, i: usize, to: Option<&str>, lua: &RawLua) -> Result { + if nargs == 0 { + return T::from_lua_arg(Nil, i, to, lua.lua()); + } + T::from_stack_arg(-nargs, i, to, lua) + } +} + +/// Multiple Lua values used for both argument passing and also for multiple return values. +#[derive(Default, Debug, Clone)] +pub struct MultiValue(VecDeque); + +impl Deref for MultiValue { + type Target = VecDeque; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MultiValue { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl MultiValue { + /// Creates an empty `MultiValue` containing no values. + #[inline] + pub const fn new() -> MultiValue { + MultiValue(VecDeque::new()) + } + + /// Creates an empty `MultiValue` container with space for at least `capacity` elements. + pub fn with_capacity(capacity: usize) -> MultiValue { + MultiValue(VecDeque::with_capacity(capacity)) + } + + /// Creates a `MultiValue` container from vector of values. + /// + /// This method works in *O*(1) time and does not allocate any additional memory. + #[inline] + pub fn from_vec(vec: Vec) -> MultiValue { + vec.into() + } + + /// Consumes the `MultiValue` and returns a vector of values. + /// + /// This method needs *O*(*n*) data movement if the circular buffer doesn't happen to be at the + /// beginning of the allocation. + #[inline] + pub fn into_vec(self) -> Vec { + self.into() + } + + #[inline] + pub(crate) fn from_lua_iter(lua: &Lua, iter: impl IntoIterator) -> Result { + let iter = iter.into_iter(); + let mut multi_value = MultiValue::with_capacity(iter.size_hint().0); + for value in iter { + multi_value.push_back(value.into_lua(lua)?); + } + Ok(multi_value) + } +} + +impl From> for MultiValue { + #[inline] + fn from(value: Vec) -> Self { + MultiValue(value.into()) + } +} + +impl From for Vec { + #[inline] + fn from(value: MultiValue) -> Self { + value.0.into() + } +} + +impl FromIterator for MultiValue { + #[inline] + fn from_iter>(iter: I) -> Self { + let mut multi_value = MultiValue::new(); + multi_value.extend(iter); + multi_value + } } -impl<'lua, T: FromLua<'lua>> FromLuaMulti<'lua> for T { - fn from_lua_multi(mut values: MultiValue<'lua>, lua: &'lua Lua) -> Result { - let res = T::from_lua(values.pop_front().unwrap_or(Nil), lua); - lua.cache_multivalue(values); - res +impl IntoIterator for MultiValue { + type Item = Value; + type IntoIter = vec_deque::IntoIter; + + #[inline] + fn into_iter(mut self) -> Self::IntoIter { + let deque = mem::take(&mut self.0); + mem::forget(self); + deque.into_iter() } } -impl<'lua> ToLuaMulti<'lua> for MultiValue<'lua> { - fn to_lua_multi(self, _: &'lua Lua) -> Result> { +impl<'a> IntoIterator for &'a MultiValue { + type Item = &'a Value; + type IntoIter = vec_deque::Iter<'a, Value>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl IntoLuaMulti for MultiValue { + #[inline] + fn into_lua_multi(self, _: &Lua) -> Result { Ok(self) } } -impl<'lua> FromLuaMulti<'lua> for MultiValue<'lua> { - fn from_lua_multi(values: MultiValue<'lua>, _: &'lua Lua) -> Result { +impl IntoLuaMulti for &MultiValue { + #[inline] + fn into_lua_multi(self, _: &Lua) -> Result { + Ok(self.clone()) + } + + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + let nresults = self.len() as i32; + check_stack(lua.state(), nresults + 1)?; + for value in &self.0 { + lua.push_value(value)?; + } + Ok(nresults) + } +} + +impl FromLuaMulti for MultiValue { + #[inline] + fn from_lua_multi(values: MultiValue, _: &Lua) -> Result { Ok(values) } } @@ -75,22 +251,46 @@ impl<'lua> FromLuaMulti<'lua> for MultiValue<'lua> { /// # Ok(()) /// # } /// ``` -/// -/// [`FromLua`]: crate::FromLua -/// [`MultiValue`]: crate::MultiValue -#[derive(Debug, Clone)] +#[derive(Default, Debug, Clone)] pub struct Variadic(Vec); impl Variadic { /// Creates an empty `Variadic` wrapper containing no values. - pub fn new() -> Variadic { + pub const fn new() -> Variadic { Variadic(Vec::new()) } + + /// Creates an empty `Variadic` container with space for at least `capacity` elements. + pub fn with_capacity(capacity: usize) -> Variadic { + Variadic(Vec::with_capacity(capacity)) + } +} + +impl Deref for Variadic { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Variadic { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From> for Variadic { + #[inline] + fn from(vec: Vec) -> Self { + Variadic(vec) + } } -impl Default for Variadic { - fn default() -> Variadic { - Variadic::new() +impl From> for Vec { + #[inline] + fn from(value: Variadic) -> Self { + value.0 } } @@ -109,82 +309,145 @@ impl IntoIterator for Variadic { } } -impl Deref for Variadic { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for Variadic { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 +impl IntoLuaMulti for Variadic { + #[inline] + fn into_lua_multi(self, lua: &Lua) -> Result { + MultiValue::from_lua_iter(lua, self) } -} -impl<'lua, T: ToLua<'lua>> ToLuaMulti<'lua> for Variadic { - fn to_lua_multi(self, lua: &'lua Lua) -> Result> { - let mut values = MultiValue::new_or_cached(lua); - values.refill(self.0.into_iter().map(|e| e.to_lua(lua)))?; - Ok(values) + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + let nresults = self.len() as i32; + check_stack(lua.state(), nresults + 1)?; + for value in self.0 { + value.push_into_stack(lua)?; + } + Ok(nresults) } } -impl<'lua, T: FromLua<'lua>> FromLuaMulti<'lua> for Variadic { - fn from_lua_multi(mut values: MultiValue<'lua>, lua: &'lua Lua) -> Result { - let res = values - .drain_all() - .map(|e| T::from_lua(e, lua)) +impl FromLuaMulti for Variadic { + #[inline] + fn from_lua_multi(mut values: MultiValue, lua: &Lua) -> Result { + values + .drain(..) + .map(|val| T::from_lua(val, lua)) .collect::>>() - .map(Variadic); - lua.cache_multivalue(values); - res + .map(Variadic) } } macro_rules! impl_tuple { () => ( - impl<'lua> ToLuaMulti<'lua> for () { - fn to_lua_multi(self, lua: &'lua Lua) -> Result> { - Ok(MultiValue::new_or_cached(lua)) + impl IntoLuaMulti for () { + #[inline] + fn into_lua_multi(self, _: &Lua) -> Result { + const { Ok(MultiValue::new()) } + } + + #[inline] + unsafe fn push_into_stack_multi(self, _lua: &RawLua) -> Result { + Ok(0) } } - impl<'lua> FromLuaMulti<'lua> for () { - fn from_lua_multi(values: MultiValue<'lua>, lua: &'lua Lua) -> Result { - lua.cache_multivalue(values); + impl FromLuaMulti for () { + #[inline] + fn from_lua_multi(_values: MultiValue, _lua: &Lua) -> Result { + Ok(()) + } + + #[inline] + unsafe fn from_stack_multi(_nvals: c_int, _lua: &RawLua) -> Result { Ok(()) } } ); ($last:ident $($name:ident)*) => ( - impl<'lua, $($name,)* $last> ToLuaMulti<'lua> for ($($name,)* $last,) - where $($name: ToLua<'lua>,)* - $last: ToLuaMulti<'lua> + impl<$($name,)* $last> IntoLuaMulti for ($($name,)* $last,) + where $($name: IntoLua,)* + $last: IntoLuaMulti { - #[allow(unused_mut)] - #[allow(non_snake_case)] - fn to_lua_multi(self, lua: &'lua Lua) -> Result> { + #[allow(unused_mut, non_snake_case)] + #[inline] + fn into_lua_multi(self, lua: &Lua) -> Result { let ($($name,)* $last,) = self; - let mut results = $last.to_lua_multi(lua)?; - push_reverse!(results, $($name.to_lua(lua)?,)*); + let mut results = $last.into_lua_multi(lua)?; + push_reverse!(results, $($name.into_lua(lua)?,)*); Ok(results) } + + #[allow(non_snake_case)] + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + let ($($name,)* $last,) = self; + let mut nresults = 0; + $( + _ = $name; + nresults += 1; + )* + check_stack(lua.state(), nresults + 1)?; + $( + $name.push_into_stack(lua)?; + )* + nresults += $last.push_into_stack_multi(lua)?; + Ok(nresults) + } } - impl<'lua, $($name,)* $last> FromLuaMulti<'lua> for ($($name,)* $last,) - where $($name: FromLua<'lua>,)* - $last: FromLuaMulti<'lua> + impl<$($name,)* $last> FromLuaMulti for ($($name,)* $last,) + where $($name: FromLua,)* + $last: FromLuaMulti { - #[allow(unused_mut)] - #[allow(non_snake_case)] - fn from_lua_multi(mut values: MultiValue<'lua>, lua: &'lua Lua) -> Result { - $(let $name = values.pop_front().unwrap_or(Nil);)* + #[allow(unused_mut, non_snake_case)] + #[inline] + fn from_lua_multi(mut values: MultiValue, lua: &Lua) -> Result { + $(let $name = FromLua::from_lua(values.pop_front().unwrap_or(Nil), lua)?;)* let $last = FromLuaMulti::from_lua_multi(values, lua)?; - Ok(($(FromLua::from_lua($name, lua)?,)* $last,)) + Ok(($($name,)* $last,)) + } + + #[allow(unused_mut, non_snake_case)] + #[inline] + fn from_lua_args(mut args: MultiValue, mut i: usize, to: Option<&str>, lua: &Lua) -> Result { + $( + let $name = FromLua::from_lua_arg(args.pop_front().unwrap_or(Nil), i, to, lua)?; + i += 1; + )* + let $last = FromLuaMulti::from_lua_args(args, i, to, lua)?; + Ok(($($name,)* $last,)) + } + + #[allow(unused_mut, non_snake_case)] + #[inline] + unsafe fn from_stack_multi(mut nvals: c_int, lua: &RawLua) -> Result { + $( + let $name = if nvals > 0 { + nvals -= 1; + FromLua::from_stack(-(nvals + 1), lua) + } else { + FromLua::from_lua(Nil, lua.lua()) + }?; + )* + let $last = FromLuaMulti::from_stack_multi(nvals, lua)?; + Ok(($($name,)* $last,)) + } + + #[allow(unused_mut, non_snake_case)] + #[inline] + unsafe fn from_stack_args(mut nargs: c_int, mut i: usize, to: Option<&str>, lua: &RawLua) -> Result { + $( + let $name = if nargs > 0 { + nargs -= 1; + FromLua::from_stack_arg(-(nargs + 1), i, to, lua) + } else { + FromLua::from_lua_arg(Nil, i, to, lua.lua()) + }?; + i += 1; + )* + let $last = FromLuaMulti::from_stack_args(nargs, i, to, lua)?; + Ok(($($name,)* $last,)) } } ); @@ -220,3 +483,13 @@ impl_tuple!(A B C D E F G H I J K L M); impl_tuple!(A B C D E F G H I J K L M N); impl_tuple!(A B C D E F G H I J K L M N O); impl_tuple!(A B C D E F G H I J K L M N O P); + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(MultiValue: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(MultiValue: Send, Sync); +} diff --git a/src/prelude.rs b/src/prelude.rs index 308f352d..fb571e4b 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -2,34 +2,49 @@ #[doc(no_inline)] pub use crate::{ - AnyUserData as LuaAnyUserData, Chunk as LuaChunk, Error as LuaError, - ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, - Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, - Integer as LuaInteger, LightUserData as LuaLightUserData, Lua, LuaOptions, - MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, - RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, String as LuaString, - Table as LuaTable, TableExt as LuaTableExt, TablePairs as LuaTablePairs, - TableSequence as LuaTableSequence, Thread as LuaThread, ThreadStatus as LuaThreadStatus, ToLua, - ToLuaMulti, UserData as LuaUserData, UserDataFields as LuaUserDataFields, + AnyUserData as LuaAnyUserData, BorrowedBytes as LuaBorrowedBytes, BorrowedStr as LuaBorrowedStr, + Either as LuaEither, Error as LuaError, FromLua, FromLuaMulti, Function as LuaFunction, + Integer as LuaInteger, IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaOptions, + LuaString, MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, + ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, + Table as LuaTable, Thread as LuaThread, UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, - Value as LuaValue, + UserDataOwned as LuaUserDataOwned, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, + UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, Variadic as LuaVariadic, + VmState as LuaVmState, WeakLua, chunk::AsChunk as AsLuaChunk, chunk::Chunk as LuaChunk, + chunk::ChunkMode as LuaChunkMode, error::ErrorContext as LuaErrorContext, + error::ExternalError as LuaExternalError, error::ExternalResult as LuaExternalResult, + function::FunctionInfo as LuaFunctionInfo, function::LuaNativeFn, function::LuaNativeFnMut, + state::GcIncParams as LuaGcIncParams, state::GcMode as LuaGcMode, table::TablePairs as LuaTablePairs, + table::TableSequence as LuaTableSequence, thread::ThreadStatus as LuaThreadStatus, }; #[cfg(not(feature = "luau"))] #[doc(no_inline)] pub use crate::HookTriggers as LuaHookTriggers; +#[cfg(any(feature = "lua54", feature = "lua55"))] +#[doc(no_inline)] +pub use crate::state::GcGenParams as LuaGcGenParams; + #[cfg(feature = "luau")] #[doc(no_inline)] -pub use crate::VmState as LuaVmState; +pub use crate::{ + Vector as LuaVector, + chunk::{CompileConstant as LuaCompileConstant, Compiler as LuaCompiler}, + luau::{ + FsRequirer as LuaFsRequirer, HeapDump as LuaHeapDump, NavigateError as LuaNavigateError, + Require as LuaRequire, + }, +}; #[cfg(feature = "async")] #[doc(no_inline)] -pub use crate::AsyncThread as LuaAsyncThread; +pub use crate::{function::LuaNativeAsyncFn, thread::AsyncThread as LuaAsyncThread}; -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] #[doc(no_inline)] pub use crate::{ - DeserializeOptions as LuaDeserializeOptions, LuaSerdeExt, + DeserializeOptions as LuaDeserializeOptions, LuaSerdeExt, SerializableValue as LuaSerializableValue, SerializeOptions as LuaSerializeOptions, }; diff --git a/src/scope.rs b/src/scope.rs index 8fd9d908..fa8cafaf 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,57 +1,43 @@ -use std::any::Any; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::marker::PhantomData; use std::mem; -use std::os::raw::{c_int, c_void}; -use std::rc::Rc; - -#[cfg(feature = "serialize")] -use serde::Serialize; use crate::error::{Error, Result}; -use crate::ffi; use crate::function::Function; -use crate::lua::Lua; -use crate::types::{Callback, CallbackUpvalue, LuaRef, MaybeSend}; -use crate::userdata::{ - AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods, -}; -use crate::util::{ - assert_stack, check_stack, get_userdata, init_userdata_metatable, push_table, rawset_field, - take_userdata, StackGuard, -}; -use crate::value::{FromLua, FromLuaMulti, MultiValue, ToLua, ToLuaMulti, Value}; - -#[cfg(feature = "lua54")] -use crate::userdata::USER_VALUE_MAXSLOT; - -#[cfg(feature = "async")] -use { - crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}, - futures_core::future::Future, - futures_util::future::{self, TryFutureExt}, -}; +use crate::state::{Lua, LuaGuard, RawLua}; +use crate::traits::{FromLuaMulti, IntoLuaMulti}; +use crate::types::{Callback, CallbackUpvalue, ScopedCallback, ValueRef}; +use crate::userdata::{AnyUserData, UserData, UserDataRegistry, UserDataStorage}; +use crate::util::{self, StackGuard, check_stack, get_metatable_ptr, get_userdata, take_userdata}; /// Constructed by the [`Lua::scope`] method, allows temporarily creating Lua userdata and -/// callbacks that are not required to be Send or 'static. +/// callbacks that are not required to be `Send` or `'static`. /// /// See [`Lua::scope`] for more details. -/// -/// [`Lua::scope`]: crate::Lua.html::scope -pub struct Scope<'lua, 'scope> { - lua: &'lua Lua, - destructors: RefCell, DestructorCallback<'lua>)>>, - _scope_invariant: PhantomData>, +pub struct Scope<'scope, 'env: 'scope> { + lua: LuaGuard, + // Internal destructors run first, then user destructors (based on the declaration order) + destructors: Destructors<'env>, + user_destructors: UserDestructors<'env>, + _scope_invariant: PhantomData<&'scope mut &'scope ()>, + _env_invariant: PhantomData<&'env mut &'env ()>, } -type DestructorCallback<'lua> = Box) -> Vec> + 'lua>; +type DestructorCallback<'a> = Box Vec>>; + +// Implement Drop on Destructors instead of Scope to avoid compilation error +struct Destructors<'a>(RefCell)>>); -impl<'lua, 'scope> Scope<'lua, 'scope> { - pub(crate) fn new(lua: &'lua Lua) -> Scope<'lua, 'scope> { +struct UserDestructors<'a>(RefCell>>); + +impl<'scope, 'env: 'scope> Scope<'scope, 'env> { + pub(crate) fn new(lua: LuaGuard) -> Self { Scope { lua, - destructors: RefCell::new(Vec::new()), + destructors: Destructors(RefCell::new(Vec::new())), + user_destructors: UserDestructors(RefCell::new(Vec::new())), _scope_invariant: PhantomData, + _env_invariant: PhantomData, } } @@ -59,27 +45,16 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { /// /// This is a version of [`Lua::create_function`] that creates a callback which expires on /// scope drop. See [`Lua::scope`] for more details. - /// - /// [`Lua::create_function`]: crate::Lua::create_function - /// [`Lua::scope`]: crate::Lua::scope - pub fn create_function<'callback, A, R, F>(&'callback self, func: F) -> Result> + pub fn create_function(&'scope self, func: F) -> Result where - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'scope + Fn(&'callback Lua, A) -> Result, + F: Fn(&Lua, A) -> Result + 'scope, + A: FromLuaMulti, + R: IntoLuaMulti, { - // Safe, because 'scope must outlive 'callback (due to Self containing 'scope), however the - // callback itself must be 'scope lifetime, so the function should not be able to capture - // anything of 'callback lifetime. 'scope can't be shortened due to being invariant, and - // the 'callback lifetime here can't be enlarged due to coming from a universal - // quantification in Lua::scope. - // - // I hope I got this explanation right, but in any case this is tested with compiletest_rs - // to make sure callbacks can't capture handles with lifetime outside the scope, inside the - // scope, and owned inside the callback itself. unsafe { - self.create_callback(Box::new(move |lua, args| { - func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) + self.create_callback(Box::new(move |rawlua, nargs| { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) })) } } @@ -88,815 +63,267 @@ impl<'lua, 'scope> Scope<'lua, 'scope> { /// /// This is a version of [`Lua::create_function_mut`] that creates a callback which expires /// on scope drop. See [`Lua::scope`] and [`Scope::create_function`] for more details. - /// - /// [`Lua::create_function_mut`]: crate::Lua::create_function_mut - /// [`Lua::scope`]: crate::Lua::scope - /// [`Scope::create_function`]: #method.create_function - pub fn create_function_mut<'callback, A, R, F>( - &'callback self, - func: F, - ) -> Result> + pub fn create_function_mut(&'scope self, func: F) -> Result where - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'scope + FnMut(&'callback Lua, A) -> Result, + F: FnMut(&Lua, A) -> Result + 'scope, + A: FromLuaMulti, + R: IntoLuaMulti, { let func = RefCell::new(func); self.create_function(move |lua, args| { - (*func - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?)(lua, args) + (*func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?)(lua, args) }) } - /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. + /// Creates a Lua userdata object from a reference to custom userdata type. /// - /// This is a version of [`Lua::create_async_function`] that creates a callback which expires on - /// scope drop. See [`Lua::scope`] and [`Lua::async_scope`] for more details. - /// - /// Requires `feature = "async"` + /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on + /// scope drop, and does not require that the userdata type be Send. This method takes + /// non-'static reference to the data. See [`Lua::scope`] for more details. /// - /// [`Lua::create_async_function`]: crate::Lua::create_async_function - /// [`Lua::scope`]: crate::Lua::scope - /// [`Lua::async_scope`]: crate::Lua::async_scope - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn create_async_function<'callback, A, R, F, FR>( - &'callback self, - func: F, - ) -> Result> + /// Userdata created with this method will not be able to be mutated from Lua. + pub fn create_userdata_ref(&'scope self, data: &'env T) -> Result where - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'scope + Fn(&'callback Lua, A) -> FR, - FR: 'callback + Future>, + T: UserData + 'static, { - unsafe { - self.create_async_callback(Box::new(move |lua, args| { - let args = match A::from_lua_multi(args, lua) { - Ok(args) => args, - Err(e) => return Box::pin(future::err(e)), - }; - Box::pin(func(lua, args).and_then(move |ret| future::ready(ret.to_lua_multi(lua)))) - })) - } + let ud = unsafe { self.lua.make_userdata(UserDataStorage::new_ref(data)) }?; + self.seal_userdata::(&ud); + Ok(ud) } - /// Create a Lua userdata object from a custom userdata type. + /// Creates a Lua userdata object from a mutable reference to custom userdata type. /// /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on - /// scope drop, and does not require that the userdata type be Send (but still requires that the - /// UserData be 'static). - /// See [`Lua::scope`] for more details. - /// - /// [`Lua::create_userdata`]: crate::Lua::create_userdata - /// [`Lua::scope`]: crate::Lua::scope - pub fn create_userdata(&self, data: T) -> Result> + /// scope drop, and does not require that the userdata type be Send. This method takes + /// non-'static mutable reference to the data. See [`Lua::scope`] for more details. + pub fn create_userdata_ref_mut(&'scope self, data: &'env mut T) -> Result where - T: 'static + UserData, + T: UserData + 'static, { - self.create_userdata_inner(UserDataCell::new(data)) + let ud = unsafe { self.lua.make_userdata(UserDataStorage::new_ref_mut(data)) }?; + self.seal_userdata::(&ud); + Ok(ud) } - /// Create a Lua userdata object from a custom serializable userdata type. - /// - /// This is a version of [`Lua::create_ser_userdata`] that creates a userdata which expires on - /// scope drop, and does not require that the userdata type be Send (but still requires that the - /// UserData be 'static). - /// See [`Lua::scope`] for more details. + /// Creates a Lua userdata object from a reference to custom Rust type. /// - /// Requires `feature = "serialize"` + /// This is a version of [`Lua::create_any_userdata`] that creates a userdata which expires on + /// scope drop, and does not require that the Rust type be Send. This method takes non-'static + /// reference to the data. See [`Lua::scope`] for more details. /// - /// [`Lua::create_ser_userdata`]: crate::Lua::create_ser_userdata - /// [`Lua::scope`]: crate::Lua::scope - #[cfg(feature = "serialize")] - #[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] - pub fn create_ser_userdata(&self, data: T) -> Result> + /// Userdata created with this method will not be able to be mutated from Lua. + pub fn create_any_userdata_ref(&'scope self, data: &'env T) -> Result where - T: 'static + UserData + Serialize, + T: 'static, { - self.create_userdata_inner(UserDataCell::new_ser(data)) + let ud = unsafe { self.lua.make_any_userdata(UserDataStorage::new_ref(data)) }?; + self.seal_userdata::(&ud); + Ok(ud) } - fn create_userdata_inner(&self, data: UserDataCell) -> Result> + /// Creates a Lua userdata object from a mutable reference to custom Rust type. + /// + /// This is a version of [`Lua::create_any_userdata`] that creates a userdata which expires on + /// scope drop, and does not require that the Rust type be Send. This method takes non-'static + /// mutable reference to the data. See [`Lua::scope`] for more details. + pub fn create_any_userdata_ref_mut(&'scope self, data: &'env mut T) -> Result where - T: 'static + UserData, + T: 'static, { - // Safe even though T may not be Send, because the parent Lua cannot be sent to another - // thread while the Scope is alive (or the returned AnyUserData handle even). - unsafe { - let ud = self.lua.make_userdata(data)?; - - #[cfg(any(feature = "lua51", feature = "luajit"))] - let newtable = self.lua.create_table()?; - let destructor: DestructorCallback = Box::new(move |ud| { - let state = ud.lua.state; - let _sg = StackGuard::new(state); - assert_stack(state, 2); - - // Check that userdata is not destructed (via `take()` call) - if ud.lua.push_userdata_ref(&ud).is_err() { - return vec![]; - } - - // Clear associated user values - #[cfg(feature = "lua54")] - for i in 1..=USER_VALUE_MAXSLOT { - ffi::lua_pushnil(state); - ffi::lua_setiuservalue(state, -2, i as c_int); - } - #[cfg(any(feature = "lua53", feature = "lua52", feature = "luau"))] - { - ffi::lua_pushnil(state); - ffi::lua_setuservalue(state, -2); - } - #[cfg(any(feature = "lua51", feature = "luajit"))] - { - ud.lua.push_ref(&newtable.0); - ffi::lua_setuservalue(state, -2); - } - - vec![Box::new(take_userdata::>(state))] - }); - self.destructors - .borrow_mut() - .push((ud.0.clone(), destructor)); - - Ok(ud) - } + let ud = unsafe { self.lua.make_any_userdata(UserDataStorage::new_ref_mut(data)) }?; + self.seal_userdata::(&ud); + Ok(ud) } - /// Create a Lua userdata object from a custom userdata type. + /// Creates a Lua userdata object from a custom userdata type. /// /// This is a version of [`Lua::create_userdata`] that creates a userdata which expires on - /// scope drop, and does not require that the userdata type be Send or 'static. See + /// scope drop, and does not require that the userdata type be `Send` or `'static`. See /// [`Lua::scope`] for more details. /// - /// Lifting the requirement that the UserData type be 'static comes with some important - /// limitations, so if you only need to eliminate the Send requirement, it is probably better to - /// use [`Scope::create_userdata`] instead. - /// /// The main limitation that comes from using non-'static userdata is that the produced userdata - /// will no longer have a `TypeId` associated with it, because `TypeId` can only work for - /// 'static types. This means that it is impossible, once the userdata is created, to get a - /// reference to it back *out* of an `AnyUserData` handle. This also implies that the + /// will no longer have a [`TypeId`] associated with it, because [`TypeId`] can only work for + /// `'static` types. This means that it is impossible, once the userdata is created, to get a + /// reference to it back *out* of an [`AnyUserData`] handle. This also implies that the /// "function" type methods that can be added via [`UserDataMethods`] (the ones that accept - /// `AnyUserData` as a first parameter) are vastly less useful. Also, there is no way to re-use - /// a single metatable for multiple non-'static types, so there is a higher cost associated with - /// creating the userdata metatable each time a new userdata is created. + /// [`AnyUserData`] as a first parameter) are vastly less useful. Also, there is no way to + /// re-use a single metatable for multiple non-'static types, so there is a higher cost + /// associated with creating the userdata metatable each time a new userdata is created. /// - /// [`Scope::create_userdata`]: #method.create_userdata - /// [`Lua::create_userdata`]: crate::Lua::create_userdata - /// [`Lua::scope`]:crate::Lua::scope + /// [`TypeId`]: std::any::TypeId /// [`UserDataMethods`]: crate::UserDataMethods - pub fn create_nonstatic_userdata(&self, data: T) -> Result> + pub fn create_userdata(&'scope self, data: T) -> Result where - T: 'scope + UserData, + T: UserData + 'env, { - let data = Rc::new(RefCell::new(data)); - - // 'callback outliving 'scope is a lie to make the types work out, required due to the - // inability to work with the more correct callback type that is universally quantified over - // 'lua. This is safe though, because `UserData::add_methods` does not get to pick the 'lua - // lifetime, so none of the static methods UserData types can add can possibly capture - // parameters. - fn wrap_method<'scope, 'lua, 'callback: 'scope, T: 'scope>( - scope: &Scope<'lua, 'scope>, - data: Rc>, - ud_ptr: *const c_void, - method: NonStaticMethod<'callback, T>, - ) -> Result> { - // On methods that actually receive the userdata, we fake a type check on the passed in - // userdata, where we pretend there is a unique type per call to - // `Scope::create_nonstatic_userdata`. You can grab a method from a userdata and call - // it on a mismatched userdata type, which when using normal 'static userdata will fail - // with a type mismatch, but here without this check would proceed as though you had - // called the method on the original value (since we otherwise completely ignore the - // first argument). - let check_ud_type = move |lua: &'callback Lua, value| { - if let Some(Value::UserData(ud)) = value { - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - lua.push_userdata_ref(&ud.0)?; - if get_userdata(lua.state, -1) as *const _ == ud_ptr { - return Ok(()); - } - } - }; - Err(Error::UserDataTypeMismatch) - }; - - match method { - NonStaticMethod::Method(method) => { - let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - check_ud_type(lua, args.pop_front())?; - let data = data.try_borrow().map_err(|_| Error::UserDataBorrowError)?; - method(lua, &*data, args) - }); - unsafe { scope.create_callback(f) } - } - NonStaticMethod::MethodMut(method) => { - let method = RefCell::new(method); - let f = Box::new(move |lua, mut args: MultiValue<'callback>| { - check_ud_type(lua, args.pop_front())?; - let mut method = method - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - let mut data = data - .try_borrow_mut() - .map_err(|_| Error::UserDataBorrowMutError)?; - (*method)(lua, &mut *data, args) - }); - unsafe { scope.create_callback(f) } - } - NonStaticMethod::Function(function) => unsafe { scope.create_callback(function) }, - NonStaticMethod::FunctionMut(function) => { - let function = RefCell::new(function); - let f = Box::new(move |lua, args| { - (*function - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?)( - lua, args - ) - }); - unsafe { scope.create_callback(f) } - } - } - } - - let mut ud_fields = NonStaticUserDataFields::default(); - let mut ud_methods = NonStaticUserDataMethods::default(); - T::add_fields(&mut ud_fields); - T::add_methods(&mut ud_methods); - + let state = self.lua.state(); unsafe { - let lua = self.lua; - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 13)?; - - #[cfg(not(feature = "luau"))] - #[allow(clippy::let_and_return)] - let ud_ptr = protect_lua!(lua.state, 0, 1, |state| { - let ud = - ffi::lua_newuserdata(state, mem::size_of::>>>()); - - // Set empty environment for Lua 5.1 - #[cfg(any(feature = "lua51", feature = "luajit"))] - { - ffi::lua_newtable(state); - ffi::lua_setuservalue(state, -2); - } + let _sg = StackGuard::new(state); + check_stack(state, 3)?; - ud - })?; + // We don't write the data to the userdata until pushing the metatable + let protect = !self.lua.unlikely_memory_error(); #[cfg(feature = "luau")] let ud_ptr = { - crate::util::push_userdata::>>>( - lua.state, - UserDataCell::new(data.clone()), - )?; - ffi::lua_touserdata(lua.state, -1) + let data = UserDataStorage::new_scoped(data); + util::push_userdata(state, data, protect)? }; - - // Prepare metatable, add meta methods first and then meta fields - let meta_methods_nrec = ud_methods.meta_methods.len() + ud_fields.meta_fields.len() + 1; - push_table(lua.state, 0, meta_methods_nrec as c_int)?; - - for (k, m) in ud_methods.meta_methods { - let data = data.clone(); - lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?; - rawset_field(lua.state, -2, k.validate()?.name())?; - } - for (k, f) in ud_fields.meta_fields { - lua.push_value(f(mem::transmute(lua))?)?; - rawset_field(lua.state, -2, k.validate()?.name())?; - } - let metatable_index = ffi::lua_absindex(lua.state, -1); - - let mut field_getters_index = None; - let field_getters_nrec = ud_fields.field_getters.len(); - if field_getters_nrec > 0 { - push_table(lua.state, 0, field_getters_nrec as c_int)?; - for (k, m) in ud_fields.field_getters { - let data = data.clone(); - lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?; - rawset_field(lua.state, -2, &k)?; - } - field_getters_index = Some(ffi::lua_absindex(lua.state, -1)); - } - - let mut field_setters_index = None; - let field_setters_nrec = ud_fields.field_setters.len(); - if field_setters_nrec > 0 { - push_table(lua.state, 0, field_setters_nrec as c_int)?; - for (k, m) in ud_fields.field_setters { - let data = data.clone(); - lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?; - rawset_field(lua.state, -2, &k)?; - } - field_setters_index = Some(ffi::lua_absindex(lua.state, -1)); - } - - let mut methods_index = None; - let methods_nrec = ud_methods.methods.len(); - if methods_nrec > 0 { - // Create table used for methods lookup - push_table(lua.state, 0, methods_nrec as c_int)?; - for (k, m) in ud_methods.methods { - let data = data.clone(); - lua.push_value(Value::Function(wrap_method(self, data, ud_ptr, m)?))?; - rawset_field(lua.state, -2, &k)?; - } - methods_index = Some(ffi::lua_absindex(lua.state, -1)); - } - - init_userdata_metatable::>>>( - lua.state, - metatable_index, - field_getters_index, - field_setters_index, - methods_index, - )?; - - let count = field_getters_index.map(|_| 1).unwrap_or(0) - + field_setters_index.map(|_| 1).unwrap_or(0) - + methods_index.map(|_| 1).unwrap_or(0); - ffi::lua_pop(lua.state, count); - - let mt_ptr = ffi::lua_topointer(lua.state, -1); - // Write userdata just before attaching metatable with `__gc` metamethod #[cfg(not(feature = "luau"))] - std::ptr::write(ud_ptr as _, UserDataCell::new(data)); - ffi::lua_setmetatable(lua.state, -2); - let ud = AnyUserData(lua.pop_ref()); - lua.register_userdata_metatable(mt_ptr, None); - - #[cfg(any(feature = "lua51", feature = "luajit"))] - let newtable = lua.create_table()?; - let destructor: DestructorCallback = Box::new(move |ud| { - let state = ud.lua.state; - let _sg = StackGuard::new(state); - assert_stack(state, 2); - - // Check that userdata is valid (very likely) - if ud.lua.push_userdata_ref(&ud).is_err() { - return vec![]; - } + let ud_ptr = util::push_uninit_userdata::>(state, protect)?; - // Deregister metatable - ffi::lua_getmetatable(state, -1); - let mt_ptr = ffi::lua_topointer(state, -1); - ffi::lua_pop(state, 1); - ud.lua.deregister_userdata_metatable(mt_ptr); - - // Clear associated user values - #[cfg(feature = "lua54")] - for i in 1..=USER_VALUE_MAXSLOT { - ffi::lua_pushnil(state); - ffi::lua_setiuservalue(state, -2, i as c_int); - } - #[cfg(any(feature = "lua53", feature = "lua52", feature = "luau"))] - { - ffi::lua_pushnil(state); - ffi::lua_setuservalue(state, -2); - } - #[cfg(any(feature = "lua51", feature = "luajit"))] - { - ud.lua.push_ref(&newtable.0); - ffi::lua_setuservalue(state, -2); - } + // Push the metatable and register it with no TypeId + let mut registry = UserDataRegistry::new_unique(self.lua.lua(), ud_ptr as *mut _); + T::register(&mut registry); + self.lua.push_userdata_metatable(registry.into_raw())?; + let mt_ptr = ffi::lua_topointer(state, -1); + self.lua.register_userdata_metatable(mt_ptr, None); - // A hack to drop non-static `T` - unsafe fn seal(t: T) -> Box { - let f: Box = Box::new(move || drop(t)); - mem::transmute(f) - } + // Write data to the pointer and attach metatable + #[cfg(not(feature = "luau"))] + std::ptr::write(ud_ptr, UserDataStorage::new_scoped(data)); + ffi::lua_setmetatable(state, -2); - let ud = Box::new(seal(take_userdata::>>>(state))); - vec![ud] - }); - self.destructors - .borrow_mut() - .push((ud.0.clone(), destructor)); + let ud = AnyUserData(self.lua.pop_ref()); + self.seal_userdata::(&ud); Ok(ud) } } - // Unsafe, because the callback can improperly capture any value with 'callback scope, such as - // improperly capturing an argument. Since the 'callback lifetime is chosen by the user and the - // lifetime of the callback itself is 'scope (non-'static), the borrow checker will happily pick - // a 'callback that outlives 'scope to allow this. In order for this to be safe, the callback - // must NOT capture any parameters. - unsafe fn create_callback<'callback>( - &self, - f: Callback<'callback, 'scope>, - ) -> Result> { - let f = mem::transmute::, Callback<'lua, 'static>>(f); - let f = self.lua.create_callback(f)?; - - let destructor: DestructorCallback = Box::new(|f| { - let state = f.lua.state; + /// Creates a Lua userdata object from a custom Rust type. + /// + /// Since the Rust type is not required to be static and implement [`UserData`] trait, + /// you need to provide a function to register fields or methods for the object. + /// + /// See also [`Scope::create_userdata`] for more details about non-static limitations. + pub fn create_any_userdata( + &'scope self, + data: T, + register: impl FnOnce(&mut UserDataRegistry), + ) -> Result + where + T: 'env, + { + let state = self.lua.state(); + let ud = unsafe { let _sg = StackGuard::new(state); - assert_stack(state, 3); + check_stack(state, 3)?; + + // We don't write the data to the userdata until pushing the metatable + let protect = !self.lua.unlikely_memory_error(); + #[cfg(feature = "luau")] + let ud_ptr = { + let data = UserDataStorage::new_scoped(data); + util::push_userdata(state, data, protect)? + }; + #[cfg(not(feature = "luau"))] + let ud_ptr = util::push_uninit_userdata::>(state, protect)?; + + // Push the metatable and register it with no TypeId + let mut registry = UserDataRegistry::new_unique(self.lua.lua(), ud_ptr as *mut _); + register(&mut registry); + self.lua.push_userdata_metatable(registry.into_raw())?; + let mt_ptr = ffi::lua_topointer(state, -1); + self.lua.register_userdata_metatable(mt_ptr, None); + + // Write data to the pointer and attach metatable + #[cfg(not(feature = "luau"))] + std::ptr::write(ud_ptr, UserDataStorage::new_scoped(data)); + ffi::lua_setmetatable(state, -2); - f.lua.push_ref(&f); + AnyUserData(self.lua.pop_ref()) + }; + self.seal_userdata::(&ud); + Ok(ud) + } - // We know the destructor has not run yet because we hold a reference to the callback. + /// Adds a destructor function to be run when the scope ends. + /// + /// This functionality is useful for cleaning up any resources after the scope ends. + /// + /// # Example + /// + /// ```rust + /// # use mlua::{Error, Lua, Result}; + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// let ud = lua.create_any_userdata(String::from("hello"))?; + /// lua.scope(|scope| { + /// scope.add_destructor(|| { + /// _ = ud.take::(); + /// }); + /// // Run the code that uses `ud` here + /// Ok(()) + /// })?; + /// assert!(matches!(ud.borrow::(), Err(Error::UserDataDestructed))); + /// # Ok(()) + /// # } + pub fn add_destructor(&'scope self, destructor: impl FnOnce() + 'env) { + self.user_destructors.0.borrow_mut().push(Box::new(destructor)); + } - ffi::lua_getupvalue(state, -1, 1); - let ud = take_userdata::(state); - ffi::lua_pushnil(state); - ffi::lua_setupvalue(state, -2, 1); + unsafe fn create_callback(&'scope self, f: ScopedCallback<'scope>) -> Result { + let f = mem::transmute::(f); + let f = self.lua.create_callback(f)?; - vec![Box::new(ud)] + let destructor: DestructorCallback = Box::new(|rawlua, vref| { + let ref_thread = rawlua.ref_thread(); + ffi::lua_getupvalue(ref_thread, vref.index, 1); + let upvalue = get_userdata::(ref_thread, -1); + let data = (*upvalue).data.take(); + ffi::lua_pop(ref_thread, 1); + vec![Box::new(move || drop(data))] }); - self.destructors - .borrow_mut() - .push((f.0.clone(), destructor)); + self.destructors.0.borrow_mut().push((f.0.clone(), destructor)); Ok(f) } - #[cfg(feature = "async")] - unsafe fn create_async_callback<'callback>( - &self, - f: AsyncCallback<'callback, 'scope>, - ) -> Result> { - let f = mem::transmute::, AsyncCallback<'lua, 'static>>(f); - let f = self.lua.create_async_callback(f)?; - - // We need to pre-allocate strings to avoid failures in destructor. - let get_poll_str = self.lua.create_string("get_poll")?; - let poll_str = self.lua.create_string("poll")?; - let destructor: DestructorCallback = Box::new(move |f| { - let state = f.lua.state; - let _sg = StackGuard::new(state); - assert_stack(state, 5); - - f.lua.push_ref(&f); - - // We know the destructor has not run yet because we hold a reference to the callback. - - // First, get the environment table - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - ffi::lua_getupvalue(state, -1, 1); - #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] - ffi::lua_getfenv(state, -1); - - // Second, get the `get_poll()` closure using the corresponding key - f.lua.push_ref(&get_poll_str.0); - ffi::lua_rawget(state, -2); - - // Destroy all upvalues - ffi::lua_getupvalue(state, -1, 1); - let upvalue1 = take_userdata::(state); - ffi::lua_pushnil(state); - ffi::lua_setupvalue(state, -2, 1); - - ffi::lua_pop(state, 1); - let mut data: Vec> = vec![Box::new(upvalue1)]; - - // Finally, get polled future and destroy it - f.lua.push_ref(&poll_str.0); - if ffi::lua_rawget(state, -2) == ffi::LUA_TFUNCTION { - ffi::lua_getupvalue(state, -1, 1); - let upvalue2 = take_userdata::(state); - ffi::lua_pushnil(state); - ffi::lua_setupvalue(state, -2, 1); - data.push(Box::new(upvalue2)); + /// Shortens the lifetime of the userdata to the lifetime of the scope. + fn seal_userdata(&self, ud: &AnyUserData) { + let destructor: DestructorCallback = Box::new(|rawlua, vref| unsafe { + // Ensure that userdata is not destructed + match rawlua.get_userdata_ref_type_id(&vref) { + Ok(Some(_)) => {} + Ok(None) => { + // Deregister metatable + let mt_ptr = get_metatable_ptr(rawlua.ref_thread(), vref.index); + rawlua.deregister_userdata_metatable(mt_ptr); + } + Err(_) => return vec![], } - data + let data = take_userdata::>(rawlua.ref_thread(), vref.index); + vec![Box::new(move || drop(data))] }); - self.destructors - .borrow_mut() - .push((f.0.clone(), destructor)); - - Ok(f) + self.destructors.0.borrow_mut().push((ud.0.clone(), destructor)); } } -impl<'lua, 'scope> Drop for Scope<'lua, 'scope> { +impl Drop for Destructors<'_> { fn drop(&mut self) { // We separate the action of invalidating the userdata in Lua and actually dropping the - // userdata type into two phases. This is so that, in the event a userdata drop panics, we - // can be sure that all of the userdata in Lua is actually invalidated. - - // All destructors are non-panicking, so this is fine - let to_drop = self - .destructors - .get_mut() - .drain(..) - .flat_map(|(r, dest)| dest(r)) - .collect::>(); - - drop(to_drop); - } -} - -enum NonStaticMethod<'lua, T> { - Method(Box) -> Result>>), - MethodMut(Box) -> Result>>), - Function(Box) -> Result>>), - FunctionMut(Box) -> Result>>), -} - -struct NonStaticUserDataMethods<'lua, T: UserData> { - methods: Vec<(Vec, NonStaticMethod<'lua, T>)>, - meta_methods: Vec<(MetaMethod, NonStaticMethod<'lua, T>)>, -} - -impl<'lua, T: UserData> Default for NonStaticUserDataMethods<'lua, T> { - fn default() -> NonStaticUserDataMethods<'lua, T> { - NonStaticUserDataMethods { - methods: Vec::new(), - meta_methods: Vec::new(), + // userdata type into two phases. This is so that, in the event a userdata drop panics, + // we can be sure that all of the userdata in Lua is actually invalidated. + + let destructors = mem::take(&mut *self.0.borrow_mut()); + if let Some(lua) = destructors.first().map(|(vref, _)| vref.lua.lock()) { + // All destructors are non-panicking, so this is fine + let to_drop = destructors + .into_iter() + .flat_map(|(vref, destructor)| destructor(&lua, vref)) + .collect::>(); + + drop(to_drop); } } } -impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'lua, T> { - fn add_method(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::Method(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_method_mut(&mut self, name: &S, mut method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - #[cfg(feature = "async")] - fn add_async_method(&mut self, _name: &S, _method: M) - where - T: Clone, - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>, - { - // The panic should never happen as async non-static code wouldn't compile - // Non-static lifetime must be bounded to 'lua lifetime - mlua_panic!("asynchronous methods are not supported for non-static userdata") - } - - fn add_function(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::Function(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_function_mut(&mut self, name: &S, mut function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - self.methods.push(( - name.as_ref().to_vec(), - NonStaticMethod::FunctionMut(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - #[cfg(feature = "async")] - fn add_async_function(&mut self, _name: &S, _function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - // The panic should never happen as async non-static code wouldn't compile - // Non-static lifetime must be bounded to 'lua lifetime - mlua_panic!("asynchronous functions are not supported for non-static userdata") - } - - fn add_meta_method(&mut self, meta: S, method: M) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, - { - self.meta_methods.push(( - meta.into(), - NonStaticMethod::Method(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_method_mut(&mut self, meta: S, mut method: M) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.meta_methods.push(( - meta.into(), - NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_method(&mut self, _meta: S, _method: M) - where - T: Clone, - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>, - { - // The panic should never happen as async non-static code wouldn't compile - // Non-static lifetime must be bounded to 'lua lifetime - mlua_panic!("asynchronous meta methods are not supported for non-static userdata") - } - - fn add_meta_function(&mut self, meta: S, function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - self.meta_methods.push(( - meta.into(), - NonStaticMethod::Function(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_function_mut(&mut self, meta: S, mut function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - self.meta_methods.push(( - meta.into(), - NonStaticMethod::FunctionMut(Box::new(move |lua, args| { - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_function(&mut self, _meta: S, _function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - // The panic should never happen as async non-static code wouldn't compile - // Non-static lifetime must be bounded to 'lua lifetime - mlua_panic!("asynchronous meta functions are not supported for non-static userdata") - } -} - -struct NonStaticUserDataFields<'lua, T: UserData> { - field_getters: Vec<(Vec, NonStaticMethod<'lua, T>)>, - field_setters: Vec<(Vec, NonStaticMethod<'lua, T>)>, - #[allow(clippy::type_complexity)] - meta_fields: Vec<(MetaMethod, Box Result>>)>, -} - -impl<'lua, T: UserData> Default for NonStaticUserDataFields<'lua, T> { - fn default() -> NonStaticUserDataFields<'lua, T> { - NonStaticUserDataFields { - field_getters: Vec::new(), - field_setters: Vec::new(), - meta_fields: Vec::new(), +impl Drop for UserDestructors<'_> { + fn drop(&mut self) { + let destructors = mem::take(&mut *self.0.borrow_mut()); + for destructor in destructors { + destructor(); } } } - -impl<'lua, T: UserData> UserDataFields<'lua, T> for NonStaticUserDataFields<'lua, T> { - fn add_field_method_get(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result, - { - self.field_getters.push(( - name.as_ref().to_vec(), - NonStaticMethod::Method(Box::new(move |lua, ud, _| { - method(lua, ud)?.to_lua_multi(lua) - })), - )); - } - - fn add_field_method_set(&mut self, name: &S, mut method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>, - { - self.field_setters.push(( - name.as_ref().to_vec(), - NonStaticMethod::MethodMut(Box::new(move |lua, ud, args| { - method(lua, ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_field_function_get(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result, - { - self.field_getters.push(( - name.as_ref().to_vec(), - NonStaticMethod::Function(Box::new(move |lua, args| { - function(lua, AnyUserData::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - })), - )); - } - - fn add_field_function_set(&mut self, name: &S, mut function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>, - { - self.field_setters.push(( - name.as_ref().to_vec(), - NonStaticMethod::FunctionMut(Box::new(move |lua, args| { - let (ud, val) = <_>::from_lua_multi(args, lua)?; - function(lua, ud, val)?.to_lua_multi(lua) - })), - )); - } - - fn add_meta_field_with(&mut self, meta: S, f: F) - where - S: Into, - F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, - R: ToLua<'lua>, - { - let meta = meta.into(); - self.meta_fields.push(( - meta.clone(), - Box::new(move |lua| { - let value = f(lua)?.to_lua(lua)?; - if meta == MetaMethod::Index || meta == MetaMethod::NewIndex { - match value { - Value::Nil | Value::Table(_) | Value::Function(_) => {} - _ => { - return Err(Error::MetaMethodTypeError { - method: meta.to_string(), - type_name: value.type_name(), - message: Some("expected nil, table or function".to_string()), - }) - } - } - } - Ok(value) - }), - )); - } -} diff --git a/src/serde/de.rs b/src/serde/de.rs index 3e4bb14a..4c5eb1b6 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -1,37 +1,39 @@ +//! Deserialize Lua values to a Rust data structure. + use std::cell::RefCell; -use std::convert::TryInto; use std::os::raw::c_void; use std::rc::Rc; -use std::string::String as StdString; +use std::result::Result as StdResult; use rustc_hash::FxHashSet; use serde::de::{self, IntoDeserializer}; use crate::error::{Error, Result}; -use crate::ffi; use crate::table::{Table, TablePairs, TableSequence}; +use crate::userdata::AnyUserData; use crate::value::Value; /// A struct for deserializing Lua values into Rust values. -#[derive(Debug)] -pub struct Deserializer<'lua> { - value: Value<'lua>, +#[derive(Debug, Default)] +pub struct Deserializer { + value: Value, options: Options, visited: Rc>>, + len: Option, // A length hint for sequences } /// A struct with options to change default deserializer behavior. #[derive(Debug, Clone, Copy)] #[non_exhaustive] pub struct Options { - /// If true, an attempt to serialize types such as [`Thread`], [`UserData`], [`LightUserData`] + /// If true, an attempt to serialize types such as [`Function`], [`Thread`], [`LightUserData`] /// and [`Error`] will cause an error. /// Otherwise these types skipped when iterating or serialized as unit type. /// /// Default: **true** /// + /// [`Function`]: crate::Function /// [`Thread`]: crate::Thread - /// [`UserData`]: crate::UserData /// [`LightUserData`]: crate::LightUserData /// [`Error`]: crate::Error pub deny_unsupported_types: bool, @@ -42,11 +44,34 @@ pub struct Options { /// /// Default: **true** pub deny_recursive_tables: bool, + + /// If true, keys in tables will be iterated in sorted order. + /// + /// Default: **false** + pub sort_keys: bool, + + /// If true, empty Lua tables will be encoded as array, instead of map. + /// + /// Default: **false** + pub encode_empty_tables_as_array: bool, + + /// If true, enable detection of mixed tables. + /// + /// A mixed table is a table that has both array-like and map-like entries or several borders. + /// See [`The Length Operator`] documentation for details about borders. + /// + /// When this option is disabled, a table with a non-zero length (with one or more borders) will + /// be always encoded as an array. + /// + /// Default: **false** + /// + /// [`The Length Operator`]: https://www.lua.org/manual/5.4/manual.html#3.4.7 + pub detect_mixed_tables: bool, } impl Default for Options { fn default() -> Self { - Self::new() + const { Self::new() } } } @@ -56,6 +81,9 @@ impl Options { Options { deny_unsupported_types: true, deny_recursive_tables: true, + sort_keys: false, + encode_empty_tables_as_array: false, + detect_mixed_tables: false, } } @@ -72,41 +100,70 @@ impl Options { /// /// [`deny_recursive_tables`]: #structfield.deny_recursive_tables #[must_use] - pub fn deny_recursive_tables(mut self, enabled: bool) -> Self { + pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self { self.deny_recursive_tables = enabled; self } + + /// Sets [`sort_keys`] option. + /// + /// [`sort_keys`]: #structfield.sort_keys + #[must_use] + pub const fn sort_keys(mut self, enabled: bool) -> Self { + self.sort_keys = enabled; + self + } + + /// Sets [`encode_empty_tables_as_array`] option. + /// + /// [`encode_empty_tables_as_array`]: #structfield.encode_empty_tables_as_array + #[must_use] + pub const fn encode_empty_tables_as_array(mut self, enabled: bool) -> Self { + self.encode_empty_tables_as_array = enabled; + self + } + + /// Sets [`detect_mixed_tables`] option. + /// + /// [`detect_mixed_tables`]: #structfield.detect_mixed_tables + #[must_use] + pub const fn detect_mixed_tables(mut self, enable: bool) -> Self { + self.detect_mixed_tables = enable; + self + } } -impl<'lua> Deserializer<'lua> { - /// Creates a new Lua Deserializer for the `Value`. - pub fn new(value: Value<'lua>) -> Self { +impl Deserializer { + /// Creates a new Lua Deserializer for the [`Value`]. + pub fn new(value: Value) -> Self { Self::new_with_options(value, Options::default()) } - /// Creates a new Lua Deserializer for the `Value` with custom options. - pub fn new_with_options(value: Value<'lua>, options: Options) -> Self { + /// Creates a new Lua Deserializer for the [`Value`] with custom options. + pub fn new_with_options(value: Value, options: Options) -> Self { Deserializer { value, options, - visited: Rc::new(RefCell::new(FxHashSet::default())), + ..Default::default() } } - fn from_parts( - value: Value<'lua>, - options: Options, - visited: Rc>>, - ) -> Self { + fn from_parts(value: Value, options: Options, visited: Rc>>) -> Self { Deserializer { value, options, visited, + ..Default::default() } } + + fn with_len(mut self, len: usize) -> Self { + self.len = Some(len); + self + } } -impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { +impl<'de> serde::Deserializer<'de> for Deserializer { type Error = Error; #[inline] @@ -118,30 +175,40 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { Value::Nil => visitor.visit_unit(), Value::Boolean(b) => visitor.visit_bool(b), #[allow(clippy::useless_conversion)] - Value::Integer(i) => { - visitor.visit_i64(i.try_into().expect("cannot convert lua_Integer to i64")) - } + Value::Integer(i) => visitor.visit_i64(i.into()), #[allow(clippy::useless_conversion)] Value::Number(n) => visitor.visit_f64(n.into()), #[cfg(feature = "luau")] - Value::Vector(_, _, _) => self.deserialize_seq(visitor), + Value::Vector(_) => self.deserialize_seq(visitor), Value::String(s) => match s.to_str() { - Ok(s) => visitor.visit_str(s), - Err(_) => visitor.visit_bytes(s.as_bytes()), + Ok(s) => visitor.visit_str(&s), + Err(_) => visitor.visit_bytes(&s.as_bytes()), }, - Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor), - Value::Table(_) => self.deserialize_map(visitor), + Value::Table(ref t) => { + if let Some(len) = t.encode_as_array(self.options) { + self.with_len(len).deserialize_seq(visitor) + } else { + self.deserialize_map(visitor) + } + } Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(), + Value::UserData(ud) if ud.is_serializable() => { + serde_userdata(ud, |value| value.deserialize_any(visitor)) + } + #[cfg(feature = "luau")] + Value::Buffer(buf) => { + let lua = buf.0.lua.lock(); + visitor.visit_bytes(buf.as_slice(&lua)) + } Value::Function(_) | Value::Thread(_) | Value::UserData(_) | Value::LightUserData(_) - | Value::Error(_) => { + | Value::Error(_) + | Value::Other(_) => { if self.options.deny_unsupported_types { - Err(de::Error::custom(format!( - "unsupported value type `{}`", - self.value.type_name() - ))) + let msg = format!("unsupported value type `{}`", self.value.type_name()); + Err(de::Error::custom(msg)) } else { visitor.visit_unit() } @@ -164,8 +231,8 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { #[inline] fn deserialize_enum( self, - _name: &str, - _variants: &'static [&'static str], + name: &'static str, + variants: &'static [&'static str], visitor: V, ) -> Result where @@ -175,14 +242,14 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { Value::Table(table) => { let _guard = RecursionGuard::new(&table, &self.visited); - let mut iter = table.pairs::(); + let mut iter = table.pairs::(); let (variant, value) = match iter.next() { Some(v) => v?, None => { return Err(de::Error::invalid_value( de::Unexpected::Map, &"map with a single key", - )) + )); } }; @@ -192,13 +259,18 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { &"map with a single key", )); } - if check_value_if_skip(&value, self.options, &self.visited)? { + let skip = check_value_for_skip(&value, self.options, &self.visited) + .map_err(|err| Error::DeserializeError(err.to_string()))?; + if skip { return Err(de::Error::custom("bad enum value")); } (variant, Some(value), Some(_guard)) } Value::String(variant) => (variant.to_str()?.to_owned(), None, None), + Value::UserData(ud) if ud.is_serializable() => { + return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor)); + } _ => return Err(de::Error::custom("bad enum value")), }; @@ -217,9 +289,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { { match self.value { #[cfg(feature = "luau")] - Value::Vector(x, y, z) => { + Value::Vector(vec) => { let mut deserializer = VecDeserializer { - vec: [x, y, z], + vec, next: 0, options: self.options, visited: self.visited, @@ -229,22 +301,22 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { Value::Table(t) => { let _guard = RecursionGuard::new(&t, &self.visited); - let len = t.raw_len() as usize; + let len = self.len.unwrap_or_else(|| t.raw_len()); let mut deserializer = SeqDeserializer { - seq: t.raw_sequence_values(), + seq: t.sequence_values().with_len(len), options: self.options, visited: self.visited, }; let seq = visitor.visit_seq(&mut deserializer)?; - if deserializer.seq.count() == 0 { + if deserializer.seq.next().is_none() { Ok(seq) } else { - Err(de::Error::invalid_length( - len, - &"fewer elements in the table", - )) + Err(de::Error::invalid_length(len, &"fewer elements in the table")) } } + Value::UserData(ud) if ud.is_serializable() => { + serde_userdata(ud, |value| value.deserialize_seq(visitor)) + } value => Err(de::Error::invalid_type( de::Unexpected::Other(value.type_name()), &"table", @@ -261,12 +333,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { } #[inline] - fn deserialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - visitor: V, - ) -> Result + fn deserialize_tuple_struct(self, _name: &'static str, _len: usize, visitor: V) -> Result where V: de::Visitor<'de>, { @@ -283,7 +350,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { let _guard = RecursionGuard::new(&t, &self.visited); let mut deserializer = MapDeserializer { - pairs: t.pairs(), + pairs: MapPairs::new(&t, self.options.sort_keys)?, value: None, options: self.options, visited: self.visited, @@ -300,6 +367,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { )) } } + Value::UserData(ud) if ud.is_serializable() => { + serde_userdata(ud, |value| value.deserialize_map(visitor)) + } value => Err(de::Error::invalid_type( de::Unexpected::Other(value.type_name()), &"table", @@ -320,20 +390,54 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { self.deserialize_map(visitor) } + #[inline] + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::UserData(ud) if ud.is_serializable() => { + serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor)) + } + _ => visitor.visit_newtype_struct(self), + } + } + + #[inline] + fn deserialize_unit(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(), + _ => self.deserialize_any(visitor), + } + } + + #[inline] + fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(), + _ => self.deserialize_any(visitor), + } + } + serde::forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes - byte_buf unit unit_struct newtype_struct - identifier ignored_any + byte_buf identifier ignored_any } } -struct SeqDeserializer<'lua> { - seq: TableSequence<'lua, Value<'lua>>, +struct SeqDeserializer<'a> { + seq: TableSequence<'a, Value>, options: Options, visited: Rc>>, } -impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> { +impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_> { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result> @@ -344,7 +448,9 @@ impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> { match self.seq.next() { Some(value) => { let value = value?; - if check_value_if_skip(&value, self.options, &self.visited)? { + let skip = check_value_for_skip(&value, self.options, &self.visited) + .map_err(|err| Error::DeserializeError(err.to_string()))?; + if skip { continue; } let visited = Rc::clone(&self.visited); @@ -366,7 +472,7 @@ impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> { #[cfg(feature = "luau")] struct VecDeserializer { - vec: [f32; 3], + vec: crate::Vector, next: usize, options: Options, visited: Rc>>, @@ -380,12 +486,11 @@ impl<'de> de::SeqAccess<'de> for VecDeserializer { where T: de::DeserializeSeed<'de>, { - match self.vec.get(self.next) { + match self.vec.0.get(self.next) { Some(&n) => { self.next += 1; let visited = Rc::clone(&self.visited); - let deserializer = - Deserializer::from_parts(Value::Number(n as _), self.options, visited); + let deserializer = Deserializer::from_parts(Value::Number(n as _), self.options, visited); seed.deserialize(deserializer).map(Some) } None => Ok(None), @@ -393,57 +498,118 @@ impl<'de> de::SeqAccess<'de> for VecDeserializer { } fn size_hint(&self) -> Option { - Some(3) + Some(crate::Vector::SIZE) + } +} + +pub(crate) enum MapPairs<'a> { + Iter(TablePairs<'a, Value, Value>), + Vec(Vec<(Value, Value)>), +} + +impl<'a> MapPairs<'a> { + pub(crate) fn new(t: &'a Table, sort_keys: bool) -> Result { + if sort_keys { + let mut pairs = t.pairs::().collect::>>()?; + pairs.sort_by(|(a, _), (b, _)| b.sort_cmp(a)); // reverse order as we pop values from the end + Ok(MapPairs::Vec(pairs)) + } else { + Ok(MapPairs::Iter(t.pairs::())) + } + } + + pub(crate) fn count(self) -> usize { + match self { + MapPairs::Iter(iter) => iter.count(), + MapPairs::Vec(vec) => vec.len(), + } + } + + pub(crate) fn size_hint(&self) -> (usize, Option) { + match self { + MapPairs::Iter(iter) => iter.size_hint(), + MapPairs::Vec(vec) => (vec.len(), Some(vec.len())), + } } } -struct MapDeserializer<'lua> { - pairs: TablePairs<'lua, Value<'lua>, Value<'lua>>, - value: Option>, +impl Iterator for MapPairs<'_> { + type Item = Result<(Value, Value)>; + + fn next(&mut self) -> Option { + match self { + MapPairs::Iter(iter) => iter.next(), + MapPairs::Vec(vec) => vec.pop().map(Ok), + } + } +} + +struct MapDeserializer<'a> { + pairs: MapPairs<'a>, + value: Option, options: Options, visited: Rc>>, processed: usize, } -impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> { - type Error = Error; - - fn next_key_seed(&mut self, seed: T) -> Result> - where - T: de::DeserializeSeed<'de>, - { +impl MapDeserializer<'_> { + fn next_key_deserializer(&mut self) -> Result> { loop { match self.pairs.next() { Some(item) => { let (key, value) = item?; - if check_value_if_skip(&key, self.options, &self.visited)? - || check_value_if_skip(&value, self.options, &self.visited)? - { + let skip_key = check_value_for_skip(&key, self.options, &self.visited) + .map_err(|err| Error::DeserializeError(err.to_string()))?; + let skip_value = check_value_for_skip(&value, self.options, &self.visited) + .map_err(|err| Error::DeserializeError(err.to_string()))?; + if skip_key || skip_value { continue; } self.processed += 1; self.value = Some(value); let visited = Rc::clone(&self.visited); let key_de = Deserializer::from_parts(key, self.options, visited); - return seed.deserialize(key_de).map(Some); + return Ok(Some(key_de)); } None => return Ok(None), } } } - fn next_value_seed(&mut self, seed: T) -> Result - where - T: de::DeserializeSeed<'de>, - { + fn next_value_deserializer(&mut self) -> Result { match self.value.take() { Some(value) => { let visited = Rc::clone(&self.visited); - seed.deserialize(Deserializer::from_parts(value, self.options, visited)) + Ok(Deserializer::from_parts(value, self.options, visited)) } None => Err(de::Error::custom("value is missing")), } } +} + +impl<'de> de::MapAccess<'de> for MapDeserializer<'_> { + type Error = Error; + + fn next_key_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + match self.next_key_deserializer() { + Ok(Some(key_de)) => seed.deserialize(key_de).map(Some), + Ok(None) => Ok(None), + Err(error) => Err(error), + } + } + + fn next_value_seed(&mut self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + match self.next_value_deserializer() { + Ok(value_de) => seed.deserialize(value_de), + Err(error) => Err(error), + } + } fn size_hint(&self) -> Option { match self.pairs.size_hint() { @@ -453,16 +619,16 @@ impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> { } } -struct EnumDeserializer<'lua> { - variant: StdString, - value: Option>, +struct EnumDeserializer { + variant: String, + value: Option, options: Options, visited: Rc>>, } -impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> { +impl<'de> de::EnumAccess<'de> for EnumDeserializer { type Error = Error; - type Variant = VariantDeserializer<'lua>; + type Variant = VariantDeserializer; fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant)> where @@ -478,13 +644,13 @@ impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> { } } -struct VariantDeserializer<'lua> { - value: Option>, +struct VariantDeserializer { + value: Option, options: Options, visited: Rc>>, } -impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> { +impl<'de> de::VariantAccess<'de> for VariantDeserializer { type Error = Error; fn unit_variant(self) -> Result<()> { @@ -502,9 +668,7 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> { T: de::DeserializeSeed<'de>, { match self.value { - Some(value) => { - seed.deserialize(Deserializer::from_parts(value, self.options, self.visited)) - } + Some(value) => seed.deserialize(Deserializer::from_parts(value, self.options, self.visited)), None => Err(de::Error::invalid_type( de::Unexpected::UnitVariant, &"newtype variant", @@ -547,18 +711,16 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> { // Adds `ptr` to the `visited` map and removes on drop // Used to track recursive tables but allow to traverse same tables multiple times -struct RecursionGuard { +pub(crate) struct RecursionGuard { ptr: *const c_void, visited: Rc>>, } impl RecursionGuard { #[inline] - fn new(table: &Table, visited: &Rc>>) -> Self { + pub(crate) fn new(table: &Table, visited: &Rc>>) -> Self { let visited = Rc::clone(visited); - let lua = table.0.lua; - let ptr = - unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) }; + let ptr = table.to_pointer(); visited.borrow_mut().insert(ptr); RecursionGuard { ptr, visited } } @@ -571,23 +733,22 @@ impl Drop for RecursionGuard { } // Checks `options` and decides should we emit an error or skip next element -fn check_value_if_skip( +pub(crate) fn check_value_for_skip( value: &Value, options: Options, visited: &RefCell>, -) -> Result { +) -> StdResult { match value { Value::Table(table) => { - let lua = table.0.lua; - let ptr = - unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) }; + let ptr = table.to_pointer(); if visited.borrow().contains(&ptr) { if options.deny_recursive_tables { - return Err(de::Error::custom("recursive table detected")); + return Err("recursive table detected"); } return Ok(true); // skip } } + Value::UserData(ud) if ud.is_serializable() => {} Value::Function(_) | Value::Thread(_) | Value::UserData(_) @@ -601,3 +762,16 @@ fn check_value_if_skip( } Ok(false) // do not skip } + +fn serde_userdata( + ud: AnyUserData, + f: impl FnOnce(serde_value::Value) -> std::result::Result, +) -> Result { + match serde_value::to_value(ud) { + Ok(value) => match f(value) { + Ok(r) => Ok(r), + Err(error) => Err(Error::DeserializeError(error.to_string())), + }, + Err(error) => Err(Error::SerializeError(error.to_string())), + } +} diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 74f0681b..2d39f59e 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,25 +1,22 @@ //! (De)Serialization support using serde. use std::os::raw::c_void; -use std::ptr; -use serde::{Deserialize, Serialize}; +use serde::de::DeserializeOwned; +use serde::ser::Serialize; use crate::error::Result; -use crate::ffi; -use crate::lua::Lua; +use crate::private::Sealed; +use crate::state::Lua; use crate::table::Table; -use crate::types::LightUserData; -use crate::util::{assert_stack, check_stack, StackGuard}; +use crate::util::check_stack; use crate::value::Value; /// Trait for serializing/deserializing Lua values using Serde. -#[cfg_attr(docsrs, doc(cfg(feature = "serialize")))] -pub trait LuaSerdeExt<'lua> { +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] +pub trait LuaSerdeExt: Sealed { /// A special value (lightuserdata) to encode/decode optional (none) values. /// - /// Requires `feature = "serialize"` - /// /// # Example /// /// ``` @@ -37,13 +34,11 @@ pub trait LuaSerdeExt<'lua> { /// Ok(()) /// } /// ``` - fn null(&'lua self) -> Value<'lua>; + fn null(&self) -> Value; /// A metatable attachable to a Lua table to systematically encode it as Array (instead of Map). - /// As result, encoded Array will contain only sequence part of the table, with the same length - /// as the `#` operator on that table. - /// - /// Requires `feature = "serialize"` + /// As a result, encoded Array will contain only sequence part of the table, with the same + /// length as the `#` operator on that table. /// /// # Example /// @@ -68,12 +63,10 @@ pub trait LuaSerdeExt<'lua> { /// Ok(()) /// } /// ``` - fn array_metatable(&'lua self) -> Table<'lua>; + fn array_metatable(&self) -> Table; /// Converts `T` into a [`Value`] instance. /// - /// Requires `feature = "serialize"` - /// /// [`Value`]: crate::Value /// /// # Example @@ -101,14 +94,10 @@ pub trait LuaSerdeExt<'lua> { /// "#).exec() /// } /// ``` - fn to_value(&'lua self, t: &T) -> Result>; + fn to_value(&self, t: &T) -> Result; /// Converts `T` into a [`Value`] instance with options. /// - /// Requires `feature = "serialize"` - /// - /// [`Value`]: crate::Value - /// /// # Example /// /// ``` @@ -126,16 +115,12 @@ pub trait LuaSerdeExt<'lua> { /// "#).exec() /// } /// ``` - fn to_value_with(&'lua self, t: &T, options: ser::Options) -> Result> + fn to_value_with(&self, t: &T, options: ser::Options) -> Result where T: Serialize + ?Sized; /// Deserializes a [`Value`] into any serde deserializable object. /// - /// Requires `feature = "serialize"` - /// - /// [`Value`]: crate::Value - /// /// # Example /// /// ``` @@ -159,14 +144,10 @@ pub trait LuaSerdeExt<'lua> { /// } /// ``` #[allow(clippy::wrong_self_convention)] - fn from_value>(&'lua self, value: Value<'lua>) -> Result; + fn from_value(&self, value: Value) -> Result; /// Deserializes a [`Value`] into any serde deserializable object with options. /// - /// Requires `feature = "serialize"` - /// - /// [`Value`]: crate::Value - /// /// # Example /// /// ``` @@ -191,53 +172,46 @@ pub trait LuaSerdeExt<'lua> { /// } /// ``` #[allow(clippy::wrong_self_convention)] - fn from_value_with>( - &'lua self, - value: Value<'lua>, - options: de::Options, - ) -> Result; + fn from_value_with(&self, value: Value, options: de::Options) -> Result; } -impl<'lua> LuaSerdeExt<'lua> for Lua { - fn null(&'lua self) -> Value<'lua> { - Value::LightUserData(LightUserData(ptr::null_mut())) +impl LuaSerdeExt for Lua { + fn null(&self) -> Value { + Value::NULL } - fn array_metatable(&'lua self) -> Table<'lua> { + fn array_metatable(&self) -> Table { + let lua = self.lock(); unsafe { - let _sg = StackGuard::new(self.state); - assert_stack(self.state, 1); - - push_array_metatable(self.state); - - Table(self.pop_ref()) + push_array_metatable(lua.ref_thread()); + Table(lua.pop_ref_thread()) } } - fn to_value(&'lua self, t: &T) -> Result> + fn to_value(&self, t: &T) -> Result where T: Serialize + ?Sized, { t.serialize(ser::Serializer::new(self)) } - fn to_value_with(&'lua self, t: &T, options: ser::Options) -> Result> + fn to_value_with(&self, t: &T, options: ser::Options) -> Result where T: Serialize + ?Sized, { t.serialize(ser::Serializer::new_with_options(self, options)) } - fn from_value(&'lua self, value: Value<'lua>) -> Result + fn from_value(&self, value: Value) -> Result where - T: Deserialize<'lua>, + T: DeserializeOwned, { T::deserialize(de::Deserializer::new(value)) } - fn from_value_with(&'lua self, value: Value<'lua>, options: de::Options) -> Result + fn from_value_with(&self, value: Value, options: de::Options) -> Result where - T: Deserialize<'lua>, + T: DeserializeOwned, { T::deserialize(de::Deserializer::new_with_options(value, options)) } @@ -268,7 +242,5 @@ static ARRAY_METATABLE_REGISTRY_KEY: u8 = 0; pub mod de; pub mod ser; -#[doc(inline)] -pub use de::Deserializer; -#[doc(inline)] -pub use ser::Serializer; +pub use de::{Deserializer, Options as DeserializeOptions}; +pub use ser::{Options as SerializeOptions, Serializer}; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index cb2141ec..d14a7189 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -1,21 +1,18 @@ -use std::os::raw::c_int; +//! Serialize a Rust data structure into Lua value. -use serde::{ser, Serialize}; +use serde::{Serialize, ser}; use super::LuaSerdeExt; use crate::error::{Error, Result}; -use crate::ffi; -use crate::lua::Lua; -use crate::string::String; +use crate::state::Lua; use crate::table::Table; -use crate::types::Integer; -use crate::util::{check_stack, StackGuard}; -use crate::value::{ToLua, Value}; +use crate::traits::IntoLua; +use crate::value::Value; /// A struct for serializing Rust values into Lua values. #[derive(Debug)] -pub struct Serializer<'lua> { - lua: &'lua Lua, +pub struct Serializer<'a> { + lua: &'a Lua, options: Options, } @@ -48,11 +45,17 @@ pub struct Options { /// [`null`]: crate::LuaSerdeExt::null /// [`Nil`]: crate::Value::Nil pub serialize_unit_to_null: bool, + + /// If true, serialize `serde_json::Number` with arbitrary_precision to a Lua number. + /// Otherwise it will be serialized as an object (what serde does). + /// + /// Default: **false** + pub detect_serde_json_arbitrary_precision: bool, } impl Default for Options { fn default() -> Self { - Self::new() + const { Self::new() } } } @@ -63,6 +66,7 @@ impl Options { set_array_metatable: true, serialize_none_to_null: true, serialize_unit_to_null: true, + detect_serde_json_arbitrary_precision: false, } } @@ -92,16 +96,30 @@ impl Options { self.serialize_unit_to_null = enabled; self } + + /// Sets [`detect_serde_json_arbitrary_precision`] option. + /// + /// This option is used to serialize `serde_json::Number` with arbitrary precision to a Lua + /// number. Otherwise it will be serialized as an object (what serde does). + /// + /// This option is disabled by default. + /// + /// [`detect_serde_json_arbitrary_precision`]: #structfield.detect_serde_json_arbitrary_precision + #[must_use] + pub const fn detect_serde_json_arbitrary_precision(mut self, enabled: bool) -> Self { + self.detect_serde_json_arbitrary_precision = enabled; + self + } } -impl<'lua> Serializer<'lua> { +impl<'a> Serializer<'a> { /// Creates a new Lua Serializer with default options. - pub fn new(lua: &'lua Lua) -> Self { + pub fn new(lua: &'a Lua) -> Self { Self::new_with_options(lua, Options::default()) } /// Creates a new Lua Serializer with custom options. - pub fn new_with_options(lua: &'lua Lua, options: Options) -> Self { + pub fn new_with_options(lua: &'a Lua, options: Options) -> Self { Serializer { lua, options } } } @@ -109,28 +127,28 @@ impl<'lua> Serializer<'lua> { macro_rules! lua_serialize_number { ($name:ident, $t:ty) => { #[inline] - fn $name(self, value: $t) -> Result> { - value.to_lua(self.lua) + fn $name(self, value: $t) -> Result { + value.into_lua(self.lua) } }; } -impl<'lua> ser::Serializer for Serializer<'lua> { - type Ok = Value<'lua>; +impl<'a> ser::Serializer for Serializer<'a> { + type Ok = Value; type Error = Error; // Associated types for keeping track of additional state while serializing // compound data structures like sequences and maps. - type SerializeSeq = SerializeVec<'lua>; - type SerializeTuple = SerializeVec<'lua>; - type SerializeTupleStruct = SerializeVec<'lua>; - type SerializeTupleVariant = SerializeTupleVariant<'lua>; - type SerializeMap = SerializeMap<'lua>; - type SerializeStruct = SerializeMap<'lua>; - type SerializeStructVariant = SerializeStructVariant<'lua>; + type SerializeSeq = SerializeSeq<'a>; + type SerializeTuple = SerializeSeq<'a>; + type SerializeTupleStruct = SerializeSeq<'a>; + type SerializeTupleVariant = SerializeTupleVariant<'a>; + type SerializeMap = SerializeMap<'a>; + type SerializeStruct = SerializeStruct<'a>; + type SerializeStructVariant = SerializeStructVariant<'a>; #[inline] - fn serialize_bool(self, value: bool) -> Result> { + fn serialize_bool(self, value: bool) -> Result { Ok(Value::Boolean(value)) } @@ -149,22 +167,22 @@ impl<'lua> ser::Serializer for Serializer<'lua> { lua_serialize_number!(serialize_f64, f64); #[inline] - fn serialize_char(self, value: char) -> Result> { + fn serialize_char(self, value: char) -> Result { self.serialize_str(&value.to_string()) } #[inline] - fn serialize_str(self, value: &str) -> Result> { + fn serialize_str(self, value: &str) -> Result { self.lua.create_string(value).map(Value::String) } #[inline] - fn serialize_bytes(self, value: &[u8]) -> Result> { + fn serialize_bytes(self, value: &[u8]) -> Result { self.lua.create_string(value).map(Value::String) } #[inline] - fn serialize_none(self) -> Result> { + fn serialize_none(self) -> Result { if self.options.serialize_none_to_null { Ok(self.lua.null()) } else { @@ -173,7 +191,7 @@ impl<'lua> ser::Serializer for Serializer<'lua> { } #[inline] - fn serialize_some(self, value: &T) -> Result> + fn serialize_some(self, value: &T) -> Result where T: Serialize + ?Sized, { @@ -181,7 +199,7 @@ impl<'lua> ser::Serializer for Serializer<'lua> { } #[inline] - fn serialize_unit(self) -> Result> { + fn serialize_unit(self) -> Result { if self.options.serialize_unit_to_null { Ok(self.lua.null()) } else { @@ -190,7 +208,7 @@ impl<'lua> ser::Serializer for Serializer<'lua> { } #[inline] - fn serialize_unit_struct(self, _name: &'static str) -> Result> { + fn serialize_unit_struct(self, _name: &'static str) -> Result { if self.options.serialize_unit_to_null { Ok(self.lua.null()) } else { @@ -204,12 +222,12 @@ impl<'lua> ser::Serializer for Serializer<'lua> { _name: &'static str, _variant_index: u32, variant: &'static str, - ) -> Result> { + ) -> Result { self.serialize_str(variant) } #[inline] - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result> + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result where T: Serialize + ?Sized, { @@ -223,7 +241,7 @@ impl<'lua> ser::Serializer for Serializer<'lua> { _variant_index: u32, variant: &'static str, value: &T, - ) -> Result> + ) -> Result where T: Serialize + ?Sized, { @@ -236,13 +254,11 @@ impl<'lua> ser::Serializer for Serializer<'lua> { #[inline] fn serialize_seq(self, len: Option) -> Result { - let len = len.unwrap_or(0) as c_int; - let table = self.lua.create_table_with_capacity(len, 0)?; + let table = self.lua.create_table_with_capacity(len.unwrap_or(0), 0)?; if self.options.set_array_metatable { - table.set_metatable(Some(self.lua.array_metatable())); + table.set_metatable(Some(self.lua.array_metatable()))?; } - let options = self.options; - Ok(SerializeVec { table, options }) + Ok(SerializeSeq::new(self.lua, table, self.options)) } #[inline] @@ -251,11 +267,12 @@ impl<'lua> ser::Serializer for Serializer<'lua> { } #[inline] - fn serialize_tuple_struct( - self, - _name: &'static str, - len: usize, - ) -> Result { + fn serialize_tuple_struct(self, name: &'static str, len: usize) -> Result { + #[cfg(feature = "luau")] + if name == "Vector" && len == crate::Vector::SIZE { + return Ok(SerializeSeq::new_vector(self.lua, self.options)); + } + _ = name; self.serialize_seq(Some(len)) } @@ -268,7 +285,8 @@ impl<'lua> ser::Serializer for Serializer<'lua> { _len: usize, ) -> Result { Ok(SerializeTupleVariant { - name: self.lua.create_string(variant)?, + lua: self.lua, + variant, table: self.lua.create_table()?, options: self.options, }) @@ -276,17 +294,32 @@ impl<'lua> ser::Serializer for Serializer<'lua> { #[inline] fn serialize_map(self, len: Option) -> Result { - let len = len.unwrap_or(0) as c_int; Ok(SerializeMap { + lua: self.lua, key: None, - table: self.lua.create_table_with_capacity(0, len)?, + table: self.lua.create_table_with_capacity(0, len.unwrap_or(0))?, options: self.options, }) } #[inline] - fn serialize_struct(self, _name: &'static str, len: usize) -> Result { - self.serialize_map(Some(len)) + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + if self.options.detect_serde_json_arbitrary_precision + && name == "$serde_json::private::Number" + && len == 1 + { + return Ok(SerializeStruct { + lua: self.lua, + inner: None, + options: self.options, + }); + } + + Ok(SerializeStruct { + lua: self.lua, + inner: Some(Value::Table(self.lua.create_table_with_capacity(0, len)?)), + options: self.options, + }) } #[inline] @@ -298,49 +331,70 @@ impl<'lua> ser::Serializer for Serializer<'lua> { len: usize, ) -> Result { Ok(SerializeStructVariant { - name: self.lua.create_string(variant)?, - table: self.lua.create_table_with_capacity(0, len as c_int)?, + lua: self.lua, + variant, + table: self.lua.create_table_with_capacity(0, len)?, options: self.options, }) } } #[doc(hidden)] -pub struct SerializeVec<'lua> { - table: Table<'lua>, +pub struct SerializeSeq<'a> { + lua: &'a Lua, + #[cfg(feature = "luau")] + vector: Option, + table: Option
, + next: usize, options: Options, } -impl<'lua> ser::SerializeSeq for SerializeVec<'lua> { - type Ok = Value<'lua>; +impl<'a> SerializeSeq<'a> { + fn new(lua: &'a Lua, table: Table, options: Options) -> Self { + Self { + lua, + #[cfg(feature = "luau")] + vector: None, + table: Some(table), + next: 0, + options, + } + } + + #[cfg(feature = "luau")] + const fn new_vector(lua: &'a Lua, options: Options) -> Self { + Self { + lua, + vector: Some(crate::Vector::zero()), + table: None, + next: 0, + options, + } + } +} + +impl ser::SerializeSeq for SerializeSeq<'_> { + type Ok = Value; type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> where T: Serialize + ?Sized, { - let lua = self.table.0.lua; - let value = lua.to_value_with(value, self.options)?; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; - - lua.push_ref(&self.table.0); - lua.push_value(value)?; - protect_lua!(lua.state, 2, 0, fn(state) { - let len = ffi::lua_rawlen(state, -2) as Integer; - ffi::lua_rawseti(state, -2, len + 1); - }) - } + let value = self.lua.to_value_with(value, self.options)?; + let table = self.table.as_ref().unwrap(); + table.raw_seti(self.next + 1, value)?; + self.next += 1; + Ok(()) } - fn end(self) -> Result> { - Ok(Value::Table(self.table)) + fn end(self) -> Result { + Ok(Value::Table(self.table.unwrap())) } } -impl<'lua> ser::SerializeTuple for SerializeVec<'lua> { - type Ok = Value<'lua>; +impl ser::SerializeTuple for SerializeSeq<'_> { + type Ok = Value; type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> @@ -350,73 +404,82 @@ impl<'lua> ser::SerializeTuple for SerializeVec<'lua> { ser::SerializeSeq::serialize_element(self, value) } - fn end(self) -> Result> { + fn end(self) -> Result { ser::SerializeSeq::end(self) } } -impl<'lua> ser::SerializeTupleStruct for SerializeVec<'lua> { - type Ok = Value<'lua>; +impl ser::SerializeTupleStruct for SerializeSeq<'_> { + type Ok = Value; type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> where T: Serialize + ?Sized, { + #[cfg(feature = "luau")] + if let Some(vector) = self.vector.as_mut() { + let value = self.lua.to_value_with(value, self.options)?; + let value = self.lua.unpack(value)?; + vector.0[self.next] = value; + self.next += 1; + return Ok(()); + } ser::SerializeSeq::serialize_element(self, value) } - fn end(self) -> Result> { + fn end(self) -> Result { + #[cfg(feature = "luau")] + if let Some(vector) = self.vector { + return Ok(Value::Vector(vector)); + } ser::SerializeSeq::end(self) } } #[doc(hidden)] -pub struct SerializeTupleVariant<'lua> { - name: String<'lua>, - table: Table<'lua>, +pub struct SerializeTupleVariant<'a> { + lua: &'a Lua, + variant: &'static str, + table: Table, options: Options, } -impl<'lua> ser::SerializeTupleVariant for SerializeTupleVariant<'lua> { - type Ok = Value<'lua>; +impl ser::SerializeTupleVariant for SerializeTupleVariant<'_> { + type Ok = Value; type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> where T: Serialize + ?Sized, { - let lua = self.table.0.lua; - let idx = self.table.raw_len() + 1; - self.table - .raw_insert(idx, lua.to_value_with(value, self.options)?) + self.table.raw_push(self.lua.to_value_with(value, self.options)?) } - fn end(self) -> Result> { - let lua = self.table.0.lua; - let table = lua.create_table()?; - table.raw_set(self.name, self.table)?; + fn end(self) -> Result { + let table = self.lua.create_table()?; + table.raw_set(self.variant, self.table)?; Ok(Value::Table(table)) } } #[doc(hidden)] -pub struct SerializeMap<'lua> { - table: Table<'lua>, - key: Option>, +pub struct SerializeMap<'a> { + lua: &'a Lua, + table: Table, + key: Option, options: Options, } -impl<'lua> ser::SerializeMap for SerializeMap<'lua> { - type Ok = Value<'lua>; +impl ser::SerializeMap for SerializeMap<'_> { + type Ok = Value; type Error = Error; fn serialize_key(&mut self, key: &T) -> Result<()> where T: Serialize + ?Sized, { - let lua = self.table.0.lua; - self.key = Some(lua.to_value_with(key, self.options)?); + self.key = Some(self.lua.to_value_with(key, self.options)?); Ok(()) } @@ -424,62 +487,90 @@ impl<'lua> ser::SerializeMap for SerializeMap<'lua> { where T: Serialize + ?Sized, { - let lua = self.table.0.lua; - let key = mlua_expect!( - self.key.take(), - "serialize_value called before serialize_key" - ); - let value = lua.to_value_with(value, self.options)?; + let key = mlua_expect!(self.key.take(), "serialize_value called before serialize_key"); + let value = self.lua.to_value_with(value, self.options)?; self.table.raw_set(key, value) } - fn end(self) -> Result> { + fn end(self) -> Result { Ok(Value::Table(self.table)) } } -impl<'lua> ser::SerializeStruct for SerializeMap<'lua> { - type Ok = Value<'lua>; +#[doc(hidden)] +pub struct SerializeStruct<'a> { + lua: &'a Lua, + inner: Option, + options: Options, +} + +impl ser::SerializeStruct for SerializeStruct<'_> { + type Ok = Value; type Error = Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> where T: Serialize + ?Sized, { - ser::SerializeMap::serialize_key(self, key)?; - ser::SerializeMap::serialize_value(self, value) + match self.inner { + Some(Value::Table(ref table)) => { + table.raw_set(key, self.lua.to_value_with(value, self.options)?)?; + } + None if self.options.detect_serde_json_arbitrary_precision => { + // A special case for `serde_json::Number` with arbitrary precision. + assert_eq!(key, "$serde_json::private::Number"); + self.inner = Some(self.lua.to_value_with(value, self.options)?); + } + _ => unreachable!(), + } + Ok(()) } - fn end(self) -> Result> { - ser::SerializeMap::end(self) + fn end(self) -> Result { + match self.inner { + Some(table @ Value::Table(_)) => Ok(table), + Some(value @ Value::String(_)) if self.options.detect_serde_json_arbitrary_precision => { + let number_s = value.to_string()?; + if number_s.contains(['.', 'e', 'E']) + && let Ok(number) = number_s.parse().map(Value::Number) + { + return Ok(number); + } + Ok(number_s + .parse() + .map(Value::Integer) + .or_else(|_| number_s.parse().map(Value::Number)) + .unwrap_or(value)) + } + _ => unreachable!(), + } } } #[doc(hidden)] -pub struct SerializeStructVariant<'lua> { - name: String<'lua>, - table: Table<'lua>, +pub struct SerializeStructVariant<'a> { + lua: &'a Lua, + variant: &'static str, + table: Table, options: Options, } -impl<'lua> ser::SerializeStructVariant for SerializeStructVariant<'lua> { - type Ok = Value<'lua>; +impl ser::SerializeStructVariant for SerializeStructVariant<'_> { + type Ok = Value; type Error = Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> where T: Serialize + ?Sized, { - let lua = self.table.0.lua; self.table - .raw_set(key, lua.to_value_with(value, self.options)?)?; + .raw_set(key, self.lua.to_value_with(value, self.options)?)?; Ok(()) } - fn end(self) -> Result> { - let lua = self.table.0.lua; - let table = lua.create_table()?; - table.raw_set(self.name, self.table)?; + fn end(self) -> Result { + let table = self.lua.create_table_with_capacity(0, 1)?; + table.raw_set(self.variant, self.table)?; Ok(Value::Table(table)) } } diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 00000000..7e624ec2 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,2475 @@ +//! Lua state management. +//! +//! This module provides the main [`Lua`] state handle together with state-specific +//! configuration and garbage collector controls. + +use std::any::TypeId; +use std::cell::{BorrowError, BorrowMutError, RefCell}; +use std::marker::PhantomData; +use std::ops::Deref; +use std::os::raw::{c_char, c_int}; +use std::panic::Location; +use std::result::Result as StdResult; +use std::{fmt, mem, ptr}; + +use crate::chunk::{AsChunk, Chunk}; +use crate::debug::Debug; +use crate::error::{Error, Result}; +use crate::function::Function; +use crate::memory::MemoryState; +use crate::multi::MultiValue; +use crate::scope::Scope; +use crate::stdlib::StdLib; +use crate::string::LuaString; +use crate::table::Table; +use crate::thread::Thread; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::types::{ + AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, MaybeSync, Number, + ReentrantMutex, ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, +}; +use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage}; +use crate::util::{StackGuard, assert_stack, check_stack, protect_lua_closure, push_string, rawset_field}; +use crate::value::{Nil, Value}; + +#[cfg(not(feature = "luau"))] +use crate::{debug::HookTriggers, types::HookKind}; + +#[cfg(any(feature = "luau", doc))] +use crate::{buffer::Buffer, chunk::Compiler}; + +#[cfg(feature = "async")] +use { + crate::types::LightUserData, + std::future::{self, Future}, + std::task::Poll, +}; + +#[cfg(feature = "serde")] +use serde::Serialize; + +pub(crate) use extra::ExtraData; +#[doc(hidden)] +pub use raw::RawLua; +pub(crate) use util::callback_error_ext; + +/// Top level Lua struct which represents an instance of Lua VM. +pub struct Lua { + pub(self) raw: XRc>, + // Controls whether garbage collection should be run on drop + pub(self) collect_garbage: bool, +} + +/// Weak reference to Lua instance. +/// +/// This can used to prevent circular references between Lua and Rust objects. +#[derive(Clone)] +pub struct WeakLua(XWeak>); + +pub(crate) struct LuaGuard(ArcReentrantMutexGuard); + +/// Tuning parameters for the incremental GC collector. +/// +/// More information can be found in the Lua [documentation]. +/// +/// [documentation]: https://www.lua.org/manual/5.5/manual.html#2.5.1 +#[non_exhaustive] +#[derive(Clone, Copy, Debug, Default)] +pub struct GcIncParams { + /// Pause between successive GC cycles, expressed as a percentage of live memory. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub pause: Option, + + /// Target heap size as a percentage of live data, controlling how aggressively + /// the GC reclaims memory (`LUA_GCSETGOAL`). + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub goal: Option, + + /// GC work performed per unit of memory allocated. + pub step_multiplier: Option, + + /// Granularity of each GC step (see Lua reference for details). + #[cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))))] + pub step_size: Option, +} + +impl GcIncParams { + /// Sets the `pause` parameter. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn pause(mut self, v: c_int) -> Self { + self.pause = Some(v); + self + } + + /// Sets the `goal` parameter. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn goal(mut self, v: c_int) -> Self { + self.goal = Some(v); + self + } + + /// Sets the `step_multiplier` parameter. + pub fn step_multiplier(mut self, v: c_int) -> Self { + self.step_multiplier = Some(v); + self + } + + /// Sets the `step_size` parameter. + #[cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))))] + pub fn step_size(mut self, v: c_int) -> Self { + self.step_size = Some(v); + self + } +} + +/// Tuning parameters for the generational GC collector (Lua 5.4+). +/// +/// More information can be found in the Lua [documentation]. +/// +/// [documentation]: https://www.lua.org/manual/5.5/manual.html#2.5.2 +#[cfg(any(feature = "lua55", feature = "lua54"))] +#[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] +#[non_exhaustive] +#[derive(Clone, Copy, Debug, Default)] +pub struct GcGenParams { + /// Frequency of minor (young-generation) collection steps. + pub minor_multiplier: Option, + + /// Threshold controlling how large the young generation can grow before triggering + /// a shift from minor to major collection. + pub minor_to_major: Option, + + /// Threshold controlling how much the major collection must shrink the heap before + /// switching back to minor (young-generation) collection. + #[cfg(feature = "lua55")] + #[cfg_attr(docsrs, doc(cfg(feature = "lua55")))] + pub major_to_minor: Option, +} + +#[cfg(any(feature = "lua55", feature = "lua54"))] +impl GcGenParams { + /// Sets the `minor_multiplier` parameter. + pub fn minor_multiplier(mut self, v: c_int) -> Self { + self.minor_multiplier = Some(v); + self + } + + /// Sets the `minor_to_major` threshold. + pub fn minor_to_major(mut self, v: c_int) -> Self { + self.minor_to_major = Some(v); + self + } + + /// Sets the `major_to_minor` parameter. + #[cfg(feature = "lua55")] + #[cfg_attr(docsrs, doc(cfg(feature = "lua55")))] + pub fn major_to_minor(mut self, v: c_int) -> Self { + self.major_to_minor = Some(v); + self + } +} + +/// Lua garbage collector (GC) operating mode. +/// +/// Use [`Lua::gc_set_mode`] to switch the collector mode and/or tune its parameters. +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum GcMode { + /// Incremental mark-and-sweep + Incremental(GcIncParams), + + /// Generational + #[cfg(any(feature = "lua55", feature = "lua54"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] + Generational(GcGenParams), +} + +/// Controls Lua interpreter behavior such as Rust panics handling. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct LuaOptions { + /// Catch Rust panics when using [`pcall`]/[`xpcall`]. + /// + /// If disabled, wraps these functions and automatically resumes panic if found. + /// Also in Lua 5.1 adds ability to provide arguments to [`xpcall`] similar to Lua >= 5.2. + /// + /// If enabled, keeps [`pcall`]/[`xpcall`] unmodified. + /// Panics are still automatically resumed if returned to the Rust side. + /// + /// Default: **true** + /// + /// [`pcall`]: https://www.lua.org/manual/5.4/manual.html#pdf-pcall + /// [`xpcall`]: https://www.lua.org/manual/5.4/manual.html#pdf-xpcall + pub catch_rust_panics: bool, + + /// Max size of thread (coroutine) object pool used to execute asynchronous functions. + /// + /// Default: **0** (disabled) + /// + /// [`lua_resetthread`]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub thread_pool_size: usize, +} + +impl Default for LuaOptions { + fn default() -> Self { + const { LuaOptions::new() } + } +} + +impl LuaOptions { + /// Returns a new instance of `LuaOptions` with default parameters. + pub const fn new() -> Self { + LuaOptions { + catch_rust_panics: true, + #[cfg(feature = "async")] + thread_pool_size: 0, + } + } + + /// Sets [`catch_rust_panics`] option. + /// + /// [`catch_rust_panics`]: #structfield.catch_rust_panics + #[must_use] + pub const fn catch_rust_panics(mut self, enabled: bool) -> Self { + self.catch_rust_panics = enabled; + self + } + + /// Sets [`thread_pool_size`] option. + /// + /// [`thread_pool_size`]: #structfield.thread_pool_size + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + #[must_use] + pub const fn thread_pool_size(mut self, size: usize) -> Self { + self.thread_pool_size = size; + self + } +} + +impl Drop for Lua { + fn drop(&mut self) { + if self.collect_garbage { + let _ = self.gc_collect(); + } + } +} + +impl Clone for Lua { + #[inline] + fn clone(&self) -> Self { + Lua { + raw: XRc::clone(&self.raw), + collect_garbage: false, + } + } +} + +impl fmt::Debug for Lua { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Lua({:p})", self.lock().state()) + } +} + +impl Default for Lua { + #[inline] + fn default() -> Self { + Lua::new() + } +} + +impl Lua { + /// Creates a new Lua state and loads the **safe** subset of the standard libraries. + /// + /// # Safety + /// The created Lua state will have _some_ safety guarantees and will not allow to load unsafe + /// standard libraries or C modules. + /// + /// See [`StdLib`] documentation for a list of unsafe modules that cannot be loaded. + pub fn new() -> Lua { + mlua_expect!( + Self::new_with(StdLib::ALL_SAFE, LuaOptions::default()), + "Cannot create a Lua state" + ) + } + + /// Creates a new Lua state and loads all the standard libraries. + /// + /// # Safety + /// The created Lua state will not have safety guarantees and will allow to load C modules. + pub unsafe fn unsafe_new() -> Lua { + Self::unsafe_new_with(StdLib::ALL, LuaOptions::default()) + } + + /// Creates a new Lua state and loads the specified safe subset of the standard libraries. + /// + /// Use the [`StdLib`] flags to specify the libraries you want to load. + /// + /// # Safety + /// The created Lua state will have _some_ safety guarantees and will not allow to load unsafe + /// standard libraries or C modules. + /// + /// See [`StdLib`] documentation for a list of unsafe modules that cannot be loaded. + pub fn new_with(libs: StdLib, options: LuaOptions) -> Result { + #[cfg(not(feature = "luau"))] + if libs.contains(StdLib::DEBUG) { + return Err(Error::SafetyError( + "The unsafe `debug` module can't be loaded using safe `new_with`".to_string(), + )); + } + #[cfg(feature = "luajit")] + if libs.contains(StdLib::FFI) { + return Err(Error::SafetyError( + "The unsafe `ffi` module can't be loaded using safe `new_with`".to_string(), + )); + } + + let lua = unsafe { Self::inner_new(libs, options) }; + + #[cfg(not(feature = "luau"))] + if libs.contains(StdLib::PACKAGE) { + mlua_expect!(lua.disable_c_modules(), "Error disabling C modules"); + } + lua.lock().mark_safe(); + + Ok(lua) + } + + /// Creates a new Lua state and loads the specified subset of the standard libraries. + /// + /// Use the [`StdLib`] flags to specify the libraries you want to load. + /// + /// # Safety + /// The created Lua state will not have safety guarantees and allow to load C modules. + pub unsafe fn unsafe_new_with(libs: StdLib, options: LuaOptions) -> Lua { + // Workaround to avoid stripping a few unused Lua symbols that could be imported + // by C modules in unsafe mode + let mut _symbols: Vec<*const extern "C-unwind" fn()> = + vec![ffi::lua_isuserdata as _, ffi::lua_tocfunction as _]; + + #[cfg(not(feature = "luau"))] + _symbols.extend_from_slice(&[ + ffi::lua_atpanic as _, + ffi::luaL_loadstring as _, + ffi::luaL_openlibs as _, + ]); + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + { + _symbols.push(ffi::lua_getglobal as _); + _symbols.push(ffi::lua_setglobal as _); + _symbols.push(ffi::luaL_setfuncs as _); + } + + Self::inner_new(libs, options) + } + + /// Creates a new Lua state with required `libs` and `options` + unsafe fn inner_new(libs: StdLib, options: LuaOptions) -> Lua { + let lua = Lua { + raw: RawLua::new(libs, &options), + collect_garbage: true, + }; + + #[cfg(feature = "luau")] + mlua_expect!(lua.configure_luau(), "Error configuring Luau"); + + lua + } + + /// Returns or constructs Lua instance from a raw state. + /// + /// Once initialized, the returned Lua instance is cached in the registry and can be retrieved + /// by calling this function again. + /// + /// # Safety + /// The `Lua` must outlive the chosen lifetime `'a`. + #[inline] + pub unsafe fn get_or_init_from_ptr<'a>(state: *mut ffi::lua_State) -> &'a Lua { + debug_assert!(!state.is_null(), "Lua state is null"); + match ExtraData::get(state) { + extra if !extra.is_null() => (*extra).lua(), + _ => { + // The `owned` flag is set to `false` as we don't own the Lua state. + RawLua::init_from_ptr(state, false); + (*ExtraData::get(state)).lua() + } + } + } + + /// Calls provided function passing a raw lua state. + /// + /// The arguments will be pushed onto the stack before calling the function. + /// + /// This method ensures that the Lua instance is locked while the function is called + /// and restores Lua stack after the function returns. + /// + /// # Example + /// ``` + /// # use mlua::{Lua, Result}; + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// let n: i32 = unsafe { + /// let nums = (3, 4, 5); + /// lua.exec_raw(nums, |state| { + /// let n = ffi::lua_gettop(state); + /// let mut sum = 0; + /// for i in 1..=n { + /// sum += ffi::lua_tointeger(state, i); + /// } + /// ffi::lua_pop(state, n); + /// ffi::lua_pushinteger(state, sum); + /// }) + /// }?; + /// assert_eq!(n, 12); + /// # Ok(()) + /// # } + /// ``` + #[allow(clippy::missing_safety_doc)] + pub unsafe fn exec_raw( + &self, + args: impl IntoLuaMulti, + f: impl FnOnce(*mut ffi::lua_State), + ) -> Result { + let lua = self.lock(); + let state = lua.state(); + let _sg = StackGuard::new(state); + let stack_start = ffi::lua_gettop(state); + let nargs = args.push_into_stack_multi(&lua)?; + check_stack(state, 3)?; + protect_lua_closure::<_, ()>(state, nargs, ffi::LUA_MULTRET, f)?; + let nresults = ffi::lua_gettop(state) - stack_start; + R::from_stack_multi(nresults, &lua) + } + + /// Calls provided function passing a reference to the [`RawLua`] handle. + /// + /// Provided [`RawLua`] handle can be used to manually pushing/popping values to/from the stack. + /// + /// # Example + /// ``` + /// # use mlua::{Lua, Result, FromLua, IntoLua, IntoLuaMulti}; + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// let n: i32 = { + /// let nums = (3, 4, 5); + /// lua.exec_raw_lua(|rawlua| unsafe { + /// nums.push_into_stack_multi(rawlua)?; + /// let mut sum = 0; + /// for _ in 0..3 { + /// sum += rawlua.pop::()?; + /// } + /// Result::Ok(sum) + /// }) + /// }?; + /// assert_eq!(n, 12); + /// # Ok(()) + /// # } + /// ``` + #[doc(hidden)] + pub fn exec_raw_lua(&self, f: impl FnOnce(&RawLua) -> R) -> R { + let lua = self.lock(); + f(&lua) + } + + /// Loads the specified subset of the standard libraries into an existing Lua state. + /// + /// Use the [`StdLib`] flags to specify the libraries you want to load. + pub fn load_std_libs(&self, libs: StdLib) -> Result<()> { + unsafe { self.lock().load_std_libs(libs) } + } + + /// Registers module into an existing Lua state using the specified value. + /// + /// After registration, the given value will always be immediately returned when the + /// given module is [required]. + /// + /// [required]: https://www.lua.org/manual/5.4/manual.html#pdf-require + pub fn register_module(&self, modname: &str, value: impl IntoLua) -> Result<()> { + #[cfg(not(feature = "luau"))] + const LOADED_MODULES_KEY: *const c_char = ffi::LUA_LOADED_TABLE; + #[cfg(feature = "luau")] + const LOADED_MODULES_KEY: *const c_char = ffi::LUA_REGISTERED_MODULES_TABLE; + + if cfg!(feature = "luau") && !modname.starts_with('@') { + return Err(Error::runtime("module name must begin with '@'")); + } + #[cfg(feature = "luau")] + let modname = modname.to_ascii_lowercase(); + unsafe { + self.exec_raw::<()>(value, |state| { + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, LOADED_MODULES_KEY); + ffi::lua_pushlstring(state, modname.as_ptr() as *const c_char, modname.len() as _); + ffi::lua_pushvalue(state, -3); + ffi::lua_rawset(state, -3); + }) + } + } + + /// Preloads module into an existing Lua state using the specified loader function. + /// + /// When the module is required, the loader function will be called with module name as the + /// first argument. + /// + /// This is similar to setting the [`package.preload[modname]`] field. + /// + /// [`package.preload[modname]`]: + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn preload_module(&self, modname: &str, func: Function) -> Result<()> { + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + let preload = unsafe { + self.exec_raw::>((), |state| { + ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_PRELOAD_TABLE); + })? + }; + #[cfg(any(feature = "lua51", feature = "luajit"))] + let preload = unsafe { + self.exec_raw::>((), |state| { + if ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_LOADED_TABLE) != ffi::LUA_TNIL { + ffi::luaL_getsubtable(state, -1, ffi::LUA_LOADLIBNAME); + ffi::luaL_getsubtable(state, -1, cstr!("preload")); + ffi::lua_rotate(state, 1, 1); + } + })? + }; + if let Some(preload) = preload { + preload.raw_set(modname, func)?; + } + Ok(()) + } + + /// Unloads module `modname`. + /// + /// This method does not support unloading binary Lua modules since they are internally cached + /// and can be unloaded only by closing Lua state. + /// + /// This is similar to calling [`Lua::register_module`] with `Nil` value. + /// + /// [`package.loaded`]: https://www.lua.org/manual/5.4/manual.html#pdf-package.loaded + pub fn unload_module(&self, modname: &str) -> Result<()> { + self.register_module(modname, Nil) + } + + // Executes module entrypoint function, which returns only one Value. + // The returned value then pushed onto the stack. + #[doc(hidden)] + #[cfg(not(tarpaulin_include))] + pub unsafe fn entrypoint(state: *mut ffi::lua_State, func: F) -> c_int + where + F: FnOnce(&Lua, A) -> Result, + A: FromLuaMulti, + R: IntoLua, + { + // Make sure that Lua is initialized + let _ = Self::get_or_init_from_ptr(state); + + callback_error_ext(state, ptr::null_mut(), true, move |extra, nargs| { + let rawlua = (*extra).raw_lua(); + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack(rawlua)?; + Ok(1) + }) + } + + // A simple module entrypoint without arguments + #[doc(hidden)] + #[cfg(not(tarpaulin_include))] + pub unsafe fn entrypoint1(state: *mut ffi::lua_State, func: F) -> c_int + where + F: FnOnce(&Lua) -> Result, + R: IntoLua, + { + Self::entrypoint(state, move |lua, _: ()| func(lua)) + } + + /// Skips memory checks for some operations. + #[doc(hidden)] + #[cfg(feature = "module")] + pub fn skip_memory_check(&self, skip: bool) { + let lua = self.lock(); + unsafe { (*lua.extra.get()).skip_memory_check = skip }; + } + + /// Enables (or disables) sandbox mode on this Lua instance. + /// + /// This method, in particular: + /// - Set all libraries to read-only + /// - Set all builtin metatables to read-only + /// - Set globals to read-only (and activates safeenv) + /// - Setup local environment table that performs writes locally and proxies reads to the global + /// environment. + /// - Allow only `count` mode in `collectgarbage` function. + /// + /// # Examples + /// + /// ``` + /// # use mlua::{Lua, Result}; + /// # #[cfg(feature = "luau")] + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// + /// lua.sandbox(true)?; + /// lua.load("var = 123").exec()?; + /// assert_eq!(lua.globals().get::("var")?, 123); + /// + /// // Restore the global environment (clear changes made in sandbox) + /// lua.sandbox(false)?; + /// assert_eq!(lua.globals().get::>("var")?, None); + /// # Ok(()) + /// # } + /// + /// # #[cfg(not(feature = "luau"))] + /// # fn main() {} + /// ``` + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn sandbox(&self, enabled: bool) -> Result<()> { + let lua = self.lock(); + unsafe { + if (*lua.extra.get()).sandboxed != enabled { + let state = lua.main_state(); + check_stack(state, 3)?; + protect_lua!(state, 0, 0, |state| { + if enabled { + ffi::luaL_sandbox(state, 1); + ffi::luaL_sandboxthread(state); + } else { + // Restore original `LUA_GLOBALSINDEX` + ffi::lua_xpush(lua.ref_thread(), state, ffi::LUA_GLOBALSINDEX); + ffi::lua_replace(state, ffi::LUA_GLOBALSINDEX); + ffi::luaL_sandbox(state, 0); + } + })?; + (*lua.extra.get()).sandboxed = enabled; + } + Ok(()) + } + } + + /// Sets or replaces a global hook function that will periodically be called as Lua code + /// executes. + /// + /// All new threads created (by mlua) after this call will use the global hook function. + /// + /// For more information see [`Lua::set_hook`]. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn set_global_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> + where + F: Fn(&Lua, &Debug) -> Result + MaybeSend + 'static, + { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).hook_triggers = triggers; + (*lua.extra.get()).hook_callback = Some(XRc::new(callback)); + lua.set_thread_hook(lua.state(), HookKind::Global) + } + } + + /// Sets a hook function that will periodically be called as Lua code executes. + /// + /// When exactly the hook function is called depends on the contents of the `triggers` + /// parameter, see [`HookTriggers`] for more details. + /// + /// The provided hook function can error, and this error will be propagated through the Lua code + /// that was executing at the time the hook was triggered. This can be used to implement a + /// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and + /// erroring once an instruction limit has been reached. + /// + /// This method sets a hook function for the *current* thread of this Lua instance. + /// If you want to set a hook function for another thread (coroutine), use + /// [`Thread::set_hook`] instead. + /// + /// # Example + /// + /// Shows each line number of code being executed by the Lua interpreter. + /// + /// ``` + /// # use mlua::{Lua, HookTriggers, Result, VmState}; + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| { + /// println!("line {:?}", debug.current_line()); + /// Ok(VmState::Continue) + /// }); + /// + /// lua.load(r#" + /// local x = 2 + 3 + /// local y = x * 63 + /// local z = string.len(x..", "..y) + /// "#).exec() + /// # } + /// ``` + /// + /// [`HookTriggers.every_nth_instruction`]: crate::HookTriggers::every_nth_instruction + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn set_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> + where + F: Fn(&Lua, &Debug) -> Result + MaybeSend + 'static, + { + let lua = self.lock(); + unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, XRc::new(callback))) } + } + + /// Removes a global hook previously set by [`Lua::set_global_hook`]. + /// + /// This function has no effect if a hook was not previously set. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn remove_global_hook(&self) { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).hook_callback = None; + (*lua.extra.get()).hook_triggers = HookTriggers::default(); + } + } + + /// Removes any hook from the current thread. + /// + /// This function has no effect if a hook was not previously set. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn remove_hook(&self) { + let lua = self.lock(); + unsafe { + ffi::lua_sethook(lua.state(), None, 0, 0); + } + } + + /// Sets an interrupt function that will periodically be called by Luau VM. + /// + /// Any Luau code is guaranteed to call this handler "eventually" + /// (in practice this can happen at any function call or at any loop iteration). + /// This is similar to `Lua::set_hook` but in more simplified form. + /// + /// The provided interrupt function can error, and this error will be propagated through + /// the Luau code that was executing at the time the interrupt was triggered. + /// Also this can be used to implement continuous execution limits by instructing Luau VM to + /// yield by returning [`VmState::Yield`]. The yield will happen only at yieldable points + /// of execution (not across metamethod/C-call boundaries). + /// + /// # Example + /// + /// Periodically yield Luau VM to suspend execution. + /// + /// ``` + /// # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; + /// # use mlua::{Lua, Result, ThreadStatus, VmState}; + /// # #[cfg(feature = "luau")] + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// let count = Arc::new(AtomicU64::new(0)); + /// lua.set_interrupt(move |_| { + /// if count.fetch_add(1, Ordering::Relaxed) % 2 == 0 { + /// return Ok(VmState::Yield); + /// } + /// Ok(VmState::Continue) + /// }); + /// + /// let co = lua.create_thread( + /// lua.load(r#" + /// local b = 0 + /// for _, x in ipairs({1, 2, 3}) do b += x end + /// "#) + /// .into_function()?, + /// )?; + /// while co.status() == ThreadStatus::Resumable { + /// co.resume::<()>(())?; + /// } + /// # Ok(()) + /// # } + /// + /// # #[cfg(not(feature = "luau"))] + /// # fn main() {} + /// ``` + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_interrupt(&self, callback: F) + where + F: Fn(&Lua) -> Result + MaybeSend + 'static, + { + unsafe extern "C-unwind" fn interrupt_proc(state: *mut ffi::lua_State, gc: c_int) { + if gc >= 0 { + // We don't support GC interrupts since they cannot survive Lua exceptions + return; + } + let result = callback_error_ext(state, ptr::null_mut(), false, move |extra, _| { + let interrupt_cb = (*extra).interrupt_callback.clone(); + let interrupt_cb = mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc"); + if XRc::strong_count(&interrupt_cb) > 2 { + return Ok(VmState::Continue); // Don't allow recursion + } + interrupt_cb((*extra).lua()) + }); + match result { + VmState::Continue => {} + VmState::Yield => { + // We can yield only at yieldable points, otherwise ignore and continue + if ffi::lua_isyieldable(state) != 0 { + ffi::lua_yield(state, 0); + } + } + } + } + + // Set interrupt callback + let lua = self.lock(); + unsafe { + (*lua.extra.get()).interrupt_callback = Some(XRc::new(callback)); + (*ffi::lua_callbacks(lua.main_state())).interrupt = Some(interrupt_proc); + } + } + + /// Removes any interrupt function previously set by `set_interrupt`. + /// + /// This function has no effect if an 'interrupt' was not previously set. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn remove_interrupt(&self) { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).interrupt_callback = None; + (*ffi::lua_callbacks(lua.main_state())).interrupt = None; + } + } + + /// Sets a thread creation callback that will be called when a thread is created. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_thread_creation_callback(&self, callback: F) + where + F: Fn(&Lua, Thread) -> Result<()> + MaybeSend + 'static, + { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).thread_creation_callback = Some(XRc::new(callback)); + (*ffi::lua_callbacks(lua.main_state())).userthread = Some(Self::userthread_proc); + } + } + + /// Sets a thread collection callback that will be called when a thread is destroyed. + /// + /// Luau GC does not support exceptions during collection, so the callback must be + /// non-panicking. If the callback panics, the program will be aborted. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_thread_collection_callback(&self, callback: F) + where + F: Fn(crate::LightUserData) + MaybeSend + 'static, + { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).thread_collection_callback = Some(XRc::new(callback)); + (*ffi::lua_callbacks(lua.main_state())).userthread = Some(Self::userthread_proc); + } + } + + #[cfg(feature = "luau")] + unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, child: *mut ffi::lua_State) { + let extra = ExtraData::get(child); + if !parent.is_null() { + // Thread is created + let callback = match (*extra).thread_creation_callback { + Some(ref cb) => cb.clone(), + None => return, + }; + if XRc::strong_count(&callback) > 2 { + return; // Don't allow recursion + } + ffi::lua_pushthread(child); + ffi::lua_xmove(child, (*extra).ref_thread, 1); + let value = Thread((*extra).raw_lua().pop_ref_thread(), child); + callback_error_ext(parent, extra, false, move |extra, _| { + callback((*extra).lua(), value) + }) + } else { + // Thread is about to be collected + let callback = match (*extra).thread_collection_callback { + Some(ref cb) => cb.clone(), + None => return, + }; + + // We need to wrap the callback call in non-unwind function as it's not safe to unwind when + // Luau GC is running. + // This will trigger `abort()` if the callback panics. + unsafe extern "C" fn run_callback( + callback: *const crate::types::ThreadCollectionCallback, + value: *mut ffi::lua_State, + ) { + (*callback)(crate::LightUserData(value as _)); + } + + (*extra).running_gc = true; + run_callback(&callback, child); + (*extra).running_gc = false; + } + } + + /// Removes any thread creation or collection callbacks previously set by + /// [`Lua::set_thread_creation_callback`] or [`Lua::set_thread_collection_callback`]. + /// + /// This function has no effect if a thread callbacks were not previously set. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn remove_thread_callbacks(&self) { + let lua = self.lock(); + unsafe { + let extra = lua.extra.get(); + (*extra).thread_creation_callback = None; + (*extra).thread_collection_callback = None; + (*ffi::lua_callbacks(lua.main_state())).userthread = None; + } + } + + /// Sets the warning function to be used by Lua to emit warnings. + #[cfg(any(feature = "lua55", feature = "lua54"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] + pub fn set_warning_function(&self, callback: F) + where + F: Fn(&Lua, &str, bool) -> Result<()> + MaybeSend + 'static, + { + use std::ffi::CStr; + use std::os::raw::{c_char, c_void}; + + unsafe extern "C-unwind" fn warn_proc(ud: *mut c_void, msg: *const c_char, tocont: c_int) { + let extra = ud as *mut ExtraData; + callback_error_ext((*extra).raw_lua().state(), extra, false, |extra, _| { + let warn_callback = (*extra).warn_callback.clone(); + let warn_callback = mlua_expect!(warn_callback, "no warning callback set in warn_proc"); + if XRc::strong_count(&warn_callback) > 2 { + return Ok(()); + } + let msg = String::from_utf8_lossy(CStr::from_ptr(msg).to_bytes()); + warn_callback((*extra).lua(), &msg, tocont != 0) + }); + } + + let lua = self.lock(); + unsafe { + (*lua.extra.get()).warn_callback = Some(XRc::new(callback)); + ffi::lua_setwarnf(lua.state(), Some(warn_proc), lua.extra.get() as *mut c_void); + } + } + + /// Removes warning function previously set by `set_warning_function`. + /// + /// This function has no effect if a warning function was not previously set. + #[cfg(any(feature = "lua55", feature = "lua54"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] + pub fn remove_warning_function(&self) { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).warn_callback = None; + ffi::lua_setwarnf(lua.state(), None, ptr::null_mut()); + } + } + + /// Emits a warning with the given message. + /// + /// A message in a call with `incomplete` set to `true` should be continued in + /// another call to this function. + #[cfg(any(feature = "lua55", feature = "lua54"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] + pub fn warning(&self, msg: impl AsRef, incomplete: bool) { + let msg = msg.as_ref(); + let mut bytes = vec![0; msg.len() + 1]; + bytes[..msg.len()].copy_from_slice(msg.as_bytes()); + let real_len = bytes.iter().position(|&c| c == 0).unwrap(); + bytes.truncate(real_len); + let lua = self.lock(); + unsafe { + ffi::lua_warning(lua.state(), bytes.as_ptr() as *const _, incomplete as c_int); + } + } + + /// Gets information about the interpreter runtime stack at the given level. + /// + /// This function calls callback `f`, passing the [`struct@Debug`] structure that can be used to + /// get information about the function executing at a given level. + /// Level `0` is the current running function, whereas level `n+1` is the function that has + /// called level `n` (except for tail calls, which do not count in the stack). + pub fn inspect_stack(&self, level: usize, f: impl FnOnce(&Debug) -> R) -> Option { + let lua = self.lock(); + unsafe { + let mut ar = mem::zeroed::(); + let level = level as c_int; + #[cfg(not(feature = "luau"))] + if ffi::lua_getstack(lua.state(), level, &mut ar) == 0 { + return None; + } + #[cfg(feature = "luau")] + if ffi::lua_getinfo(lua.state(), level, cstr!(""), &mut ar) == 0 { + return None; + } + + Some(f(&Debug::new(&lua, level, &mut ar))) + } + } + + /// Creates a traceback of the call stack at the given level. + /// + /// The `msg` parameter, if provided, is added at the beginning of the traceback. + /// The `level` parameter works the same way as in [`Lua::inspect_stack`]. + pub fn traceback(&self, msg: Option<&str>, level: usize) -> Result { + let lua = self.lock(); + unsafe { + check_stack(lua.state(), 3)?; + protect_lua!(lua.state(), 0, 1, |state| { + let msg = match msg { + Some(s) => ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()), + None => ptr::null(), + }; + // `protect_lua` adds it's own call frame, so we need to increase level by 1 + ffi::luaL_traceback(state, state, msg, (level + 1) as c_int); + })?; + Ok(LuaString(lua.pop_ref())) + } + } + + /// Returns the amount of memory (in bytes) currently used inside this Lua state. + pub fn used_memory(&self) -> usize { + let lua = self.lock(); + let state = lua.main_state(); + unsafe { + match MemoryState::get(state) { + mem_state if !mem_state.is_null() => (*mem_state).used_memory(), + _ => { + // Get data from the Lua GC + let used_kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0); + let used_kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0); + (used_kbytes as usize) * 1024 + (used_kbytes_rem as usize) + } + } + } + } + + /// Sets a memory limit (in bytes) on this Lua state. + /// + /// Once an allocation occurs that would pass this memory limit, a `Error::MemoryError` is + /// generated instead. + /// Returns previous limit (zero means no limit). + /// + /// Does not work in module mode where Lua state is managed externally. + pub fn set_memory_limit(&self, limit: usize) -> Result { + let lua = self.lock(); + unsafe { + match MemoryState::get(lua.state()) { + mem_state if !mem_state.is_null() => Ok((*mem_state).set_memory_limit(limit)), + _ => Err(Error::MemoryControlNotAvailable), + } + } + } + + /// Returns `true` if the garbage collector is currently running automatically. + #[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luau" + ))] + pub fn gc_is_running(&self) -> bool { + let lua = self.lock(); + unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCISRUNNING, 0) != 0 } + } + + /// Stops the Lua GC from running. + pub fn gc_stop(&self) { + let lua = self.lock(); + unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCSTOP, 0) }; + } + + /// Restarts the Lua GC if it is not running. + pub fn gc_restart(&self) { + let lua = self.lock(); + unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCRESTART, 0) }; + } + + /// Performs a full garbage-collection cycle. + /// + /// It may be necessary to call this function twice to collect all currently unreachable + /// objects. Once to finish the current gc cycle, and once to start and finish the next cycle. + pub fn gc_collect(&self) -> Result<()> { + let lua = self.lock(); + let state = lua.main_state(); + unsafe { + check_stack(state, 2)?; + protect_lua!(state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0)) + } + } + + /// Performs a basic step of garbage collection. + /// + /// In incremental mode, a basic step corresponds to the current step size. In generational + /// mode, a basic step performs a full minor collection or an incremental step, if the collector + /// has scheduled one. + /// + /// In incremental mode, returns `true` if this step has finished a collection cycle. + /// In generational mode, returns `true` if the step finished a major collection. + pub fn gc_step(&self) -> Result { + let lua = self.lock(); + let state = lua.main_state(); + unsafe { + check_stack(state, 3)?; + protect_lua!(state, 0, 0, |state| { + ffi::lua_gc(state, ffi::LUA_GCSTEP, 0) != 0 + }) + } + } + + /// Switches the GC to the given mode with the provided parameters. + /// + /// Returns the previous [`GcMode`]. The returned value's parameter fields are always + /// `None` because Lua's C API does not provide a way to read back current parameter values + /// without changing them. + /// + /// # Examples + /// + /// Switch to generational mode (Lua 5.4+): + /// ```ignore + /// let prev = lua.gc_set_mode(GcMode::Generational(GcGenParams::default())); + /// ``` + /// + /// Switch to incremental mode with custom parameters: + /// ```ignore + /// lua.gc_set_mode(GcMode::Incremental( + /// GcIncParams::default().pause(200).step_multiplier(100) + /// )); + /// ``` + pub fn gc_set_mode(&self, mode: GcMode) -> GcMode { + let lua = self.lock(); + let state = lua.main_state(); + + match mode { + #[cfg(feature = "lua55")] + GcMode::Incremental(params) => unsafe { + if let Some(v) = params.pause { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPPAUSE, v); + } + if let Some(v) = params.step_multiplier { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPSTEPMUL, v); + } + if let Some(v) = params.step_size { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPSTEPSIZE, v); + } + match ffi::lua_gc(state, ffi::LUA_GCINC) { + ffi::LUA_GCINC => GcMode::Incremental(GcIncParams::default()), + ffi::LUA_GCGEN => GcMode::Generational(GcGenParams::default()), + _ => unreachable!(), + } + }, + #[cfg(feature = "lua54")] + GcMode::Incremental(params) => unsafe { + let pause = params.pause.unwrap_or(0); + let step_mul = params.step_multiplier.unwrap_or(0); + let step_size = params.step_size.unwrap_or(0); + match ffi::lua_gc(state, ffi::LUA_GCINC, pause, step_mul, step_size) { + ffi::LUA_GCINC => GcMode::Incremental(GcIncParams::default()), + ffi::LUA_GCGEN => GcMode::Generational(GcGenParams::default()), + _ => unreachable!(), + } + }, + #[cfg(any(feature = "lua53", feature = "lua52", feature = "lua51", feature = "luajit"))] + GcMode::Incremental(params) => unsafe { + if let Some(v) = params.pause { + ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, v); + } + if let Some(v) = params.step_multiplier { + ffi::lua_gc(state, ffi::LUA_GCSETSTEPMUL, v); + } + GcMode::Incremental(GcIncParams::default()) + }, + #[cfg(feature = "luau")] + GcMode::Incremental(params) => unsafe { + if let Some(v) = params.goal { + ffi::lua_gc(state, ffi::LUA_GCSETGOAL, v); + } + if let Some(v) = params.step_multiplier { + ffi::lua_gc(state, ffi::LUA_GCSETSTEPMUL, v); + } + if let Some(v) = params.step_size { + ffi::lua_gc(state, ffi::LUA_GCSETSTEPSIZE, v); + } + GcMode::Incremental(GcIncParams::default()) + }, + + #[cfg(feature = "lua55")] + GcMode::Generational(params) => unsafe { + if let Some(v) = params.minor_multiplier { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPMINORMUL, v); + } + if let Some(v) = params.minor_to_major { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPMINORMAJOR, v); + } + if let Some(v) = params.major_to_minor { + ffi::lua_gc(state, ffi::LUA_GCPARAM, ffi::LUA_GCPMAJORMINOR, v); + } + match ffi::lua_gc(state, ffi::LUA_GCGEN) { + ffi::LUA_GCGEN => GcMode::Generational(GcGenParams::default()), + ffi::LUA_GCINC => GcMode::Incremental(GcIncParams::default()), + _ => unreachable!(), + } + }, + #[cfg(feature = "lua54")] + GcMode::Generational(params) => unsafe { + let minor = params.minor_multiplier.unwrap_or(0); + let minor_to_major = params.minor_to_major.unwrap_or(0); + match ffi::lua_gc(state, ffi::LUA_GCGEN, minor, minor_to_major) { + ffi::LUA_GCGEN => GcMode::Generational(GcGenParams::default()), + ffi::LUA_GCINC => GcMode::Incremental(GcIncParams::default()), + _ => unreachable!(), + } + }, + } + } + + /// Sets a default Luau compiler (with custom options). + /// + /// This compiler will be used by default to load all Lua chunks + /// including via `require` function. + /// + /// See [`Compiler`] for details and possible options. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_compiler(&self, compiler: Compiler) { + let lua = self.lock(); + unsafe { (*lua.extra.get()).compiler = Some(compiler) }; + } + + /// Toggles JIT compilation mode for new chunks of code. + /// + /// By default JIT is enabled. Changing this option does not have any effect on + /// already loaded functions. + #[cfg(any(feature = "luau-jit", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau-jit")))] + pub fn enable_jit(&self, enable: bool) { + let lua = self.lock(); + unsafe { (*lua.extra.get()).enable_jit = enable }; + } + + /// Sets Luau feature flag (global setting). + /// + /// See https://github.com/luau-lang/luau/blob/master/CONTRIBUTING.md#feature-flags for details. + #[cfg(feature = "luau")] + #[doc(hidden)] + #[allow(clippy::result_unit_err)] + pub fn set_fflag(name: &str, enabled: bool) -> StdResult<(), ()> { + if let Ok(name) = std::ffi::CString::new(name) + && unsafe { ffi::luau_setfflag(name.as_ptr(), enabled as c_int) != 0 } + { + return Ok(()); + } + Err(()) + } + + /// Returns Lua source code as a `Chunk` builder type. + /// + /// In order to actually compile or run the resulting code, you must call [`Chunk::exec`] or + /// similar on the returned builder. Code is not even parsed until one of these methods is + /// called. + /// + /// [`Chunk::exec`]: crate::Chunk::exec + #[track_caller] + pub fn load<'a>(&self, chunk: impl AsChunk + 'a) -> Chunk<'a> { + self.load_with_location(chunk, Location::caller()) + } + + pub(crate) fn load_with_location<'a>( + &self, + chunk: impl AsChunk + 'a, + location: &'static Location<'static>, + ) -> Chunk<'a> { + Chunk { + lua: self.weak(), + name: chunk + .name() + .unwrap_or_else(|| format!("@{}:{}", location.file(), location.line())), + env: chunk.environment(self), + mode: chunk.mode(), + source: chunk.source(), + #[cfg(feature = "luau")] + compiler: unsafe { (*self.lock().extra.get()).compiler.clone() }, + } + } + + /// Creates and returns an interned Lua string. + /// + /// Lua strings can be arbitrary `[u8]` data including embedded nulls, so in addition to `&str` + /// and `&String`, you can also pass plain `&[u8]` here. + #[inline] + pub fn create_string(&self, s: impl AsRef<[u8]>) -> Result { + unsafe { self.lock().create_string(s.as_ref()) } + } + + /// Creates and returns an external Lua string. + /// + /// External string is a string where the memory is managed by Rust code, and Lua only holds a + /// reference to it. This can be used to avoid copying large strings into Lua memory. + #[cfg(feature = "lua55")] + #[cfg_attr(docsrs, doc(cfg(feature = "lua55")))] + #[inline] + pub fn create_external_string(&self, s: impl Into>) -> Result { + unsafe { self.lock().create_external_string(s.into()) } + } + + /// Creates and returns a Luau [buffer] object from a byte slice of data. + /// + /// [buffer]: https://luau.org/library#buffer-library + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn create_buffer(&self, data: impl AsRef<[u8]>) -> Result { + let lua = self.lock(); + let data = data.as_ref(); + unsafe { + let (ptr, buffer) = lua.create_buffer_with_capacity(data.len())?; + ptr.copy_from_nonoverlapping(data.as_ptr(), data.len()); + Ok(buffer) + } + } + + /// Creates and returns a Luau [buffer] object with the specified size. + /// + /// Size limit is 1GB. All bytes will be initialized to zero. + /// + /// [buffer]: https://luau.org/library#buffer-library + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn create_buffer_with_capacity(&self, size: usize) -> Result { + unsafe { Ok(self.lock().create_buffer_with_capacity(size)?.1) } + } + + /// Creates and returns a new empty table. + #[inline] + pub fn create_table(&self) -> Result
{ + self.create_table_with_capacity(0, 0) + } + + /// Creates and returns a new empty table, with the specified capacity. + /// + /// - `narr` is a hint for how many elements the table will have as a sequence. + /// - `nrec` is a hint for how many other elements the table will have. + /// + /// Lua may use these hints to preallocate memory for the new table. + pub fn create_table_with_capacity(&self, narr: usize, nrec: usize) -> Result
{ + unsafe { self.lock().create_table_with_capacity(narr, nrec) } + } + + /// Creates a table and fills it with values from an iterator. + pub fn create_table_from(&self, iter: impl IntoIterator) -> Result
+ where + K: IntoLua, + V: IntoLua, + { + unsafe { self.lock().create_table_from(iter) } + } + + /// Creates a table from an iterator of values, using `1..` as the keys. + pub fn create_sequence_from(&self, iter: impl IntoIterator) -> Result
+ where + T: IntoLua, + { + unsafe { self.lock().create_sequence_from(iter) } + } + + /// Wraps a Rust function or closure, creating a callable Lua function handle to it. + /// + /// The function's return value is always a `Result`: If the function returns `Err`, the error + /// is raised as a Lua error, which can be caught using `(x)pcall` or bubble up to the Rust code + /// that invoked the Lua code. This allows using the `?` operator to propagate errors through + /// intermediate Lua code. + /// + /// If the function returns `Ok`, the contained value will be converted to one or more Lua + /// values. For details on Rust-to-Lua conversions, refer to the [`IntoLua`] and + /// [`IntoLuaMulti`] traits. + /// + /// # Examples + /// + /// Create a function which prints its argument: + /// + /// ``` + /// # use mlua::{Lua, Result}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let greet = lua.create_function(|_, name: String| { + /// println!("Hello, {}!", name); + /// Ok(()) + /// }); + /// # let _ = greet; // used + /// # Ok(()) + /// # } + /// ``` + /// + /// Use tuples to accept multiple arguments: + /// + /// ``` + /// # use mlua::{Lua, Result}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let print_person = lua.create_function(|_, (name, age): (String, u8)| { + /// println!("{} is {} years old!", name, age); + /// Ok(()) + /// }); + /// # let _ = print_person; // used + /// # Ok(()) + /// # } + /// ``` + pub fn create_function(&self, func: F) -> Result + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + (self.lock()).create_callback(Box::new(move |rawlua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, rawlua)?; + func(rawlua.lua(), args)?.push_into_stack_multi(rawlua) + })) + } + + /// Wraps a Rust mutable closure, creating a callable Lua function handle to it. + /// + /// This is a version of [`Lua::create_function`] that accepts a `FnMut` argument. + pub fn create_function_mut(&self, func: F) -> Result + where + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let func = RefCell::new(func); + self.create_function(move |lua, args| { + (*func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?)(lua, args) + }) + } + + /// Wraps a C function, creating a callable Lua function handle to it. + /// + /// # Safety + /// This function is unsafe because provides a way to execute unsafe C function. + pub unsafe fn create_c_function(&self, func: ffi::lua_CFunction) -> Result { + let lua = self.lock(); + if cfg!(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52" + )) { + ffi::lua_pushcfunction(lua.ref_thread(), func); + return Ok(Function(lua.pop_ref_thread())); + } + + // Lua <5.2 requires memory allocation to push a C function + let state = lua.state(); + { + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + if lua.unlikely_memory_error() { + ffi::lua_pushcfunction(state, func); + } else { + protect_lua!(state, 0, 1, |state| ffi::lua_pushcfunction(state, func))?; + } + Ok(Function(lua.pop_ref())) + } + } + + /// Wraps a Rust async function or closure, creating a callable Lua function handle to it. + /// + /// While executing the function Rust will poll the Future and if the result is not ready, + /// call `yield()` passing internal representation of a `Poll::Pending` value. + /// + /// The function must be called inside Lua coroutine ([`Thread`]) to be able to suspend its + /// execution. An executor should be used to poll [`AsyncThread`] and mlua will take a provided + /// Waker in that case. Otherwise noop waker will be used if try to call the function outside of + /// Rust executors. + /// + /// The family of `call_async()` functions takes care about creating [`Thread`]. + /// + /// # Examples + /// + /// Non blocking sleep: + /// + /// ``` + /// use std::time::Duration; + /// use mlua::{Lua, Result}; + /// + /// async fn sleep(_lua: Lua, n: u64) -> Result<&'static str> { + /// tokio::time::sleep(Duration::from_millis(n)).await; + /// Ok("done") + /// } + /// + /// #[tokio::main] + /// async fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.globals().set("sleep", lua.create_async_function(sleep)?)?; + /// let res: String = lua.load("return sleep(...)").call_async(100).await?; // Sleep 100ms + /// assert_eq!(res, "done"); + /// Ok(()) + /// } + /// ``` + /// + /// [`AsyncThread`]: crate::thread::AsyncThread + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn create_async_function(&self, func: F) -> Result + where + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + // In future we should switch to async closures when they are stable to capture `&Lua` + // See https://rust-lang.github.io/rfcs/3668-async-closures.html + (self.lock()).create_async_callback(Box::new(move |rawlua, nargs| unsafe { + let args = match A::from_stack_args(nargs, 1, None, rawlua) { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = func(lua.clone(), args); + Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) }) + })) + } + + /// Wraps a Lua function into a new thread (or coroutine). + /// + /// Equivalent to `coroutine.create`. + pub fn create_thread(&self, func: Function) -> Result { + unsafe { self.lock().create_thread(&func) } + } + + /// Creates a Lua userdata object from a custom userdata type. + /// + /// All userdata instances of the same type `T` shares the same metatable. + #[inline] + pub fn create_userdata(&self, data: T) -> Result + where + T: UserData + MaybeSend + MaybeSync + 'static, + { + unsafe { self.lock().make_userdata(UserDataStorage::new(data)) } + } + + /// Creates a Lua userdata object from a custom serializable userdata type. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + #[inline] + pub fn create_ser_userdata(&self, data: T) -> Result + where + T: UserData + Serialize + MaybeSend + MaybeSync + 'static, + { + unsafe { self.lock().make_userdata(UserDataStorage::new_ser(data)) } + } + + /// Creates a Lua userdata object from a custom Rust type. + /// + /// You can register the type using [`Lua::register_userdata_type`] to add fields or methods + /// _before_ calling this method. + /// Otherwise, the userdata object will have an empty metatable. + /// + /// All userdata instances of the same type `T` shares the same metatable. + #[inline] + pub fn create_any_userdata(&self, data: T) -> Result + where + T: MaybeSend + MaybeSync + 'static, + { + unsafe { self.lock().make_any_userdata(UserDataStorage::new(data)) } + } + + /// Creates a Lua userdata object from a custom serializable Rust type. + /// + /// See [`Lua::create_any_userdata`] for more details. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + #[inline] + pub fn create_ser_any_userdata(&self, data: T) -> Result + where + T: Serialize + MaybeSend + MaybeSync + 'static, + { + unsafe { (self.lock()).make_any_userdata(UserDataStorage::new_ser(data)) } + } + + /// Registers a custom Rust type in Lua to use in userdata objects. + /// + /// This methods provides a way to add fields or methods to userdata objects of a type `T`. + pub fn register_userdata_type(&self, f: impl FnOnce(&mut UserDataRegistry)) -> Result<()> { + let type_id = TypeId::of::(); + let mut registry = UserDataRegistry::new(self); + f(&mut registry); + + let lua = self.lock(); + unsafe { + // Deregister the type if it already registered + if let Some(table_id) = (*lua.extra.get()).registered_userdata_t.remove(&type_id) { + ffi::luaL_unref(lua.state(), ffi::LUA_REGISTRYINDEX, table_id); + } + + // Add to "pending" registration map + ((*lua.extra.get()).pending_userdata_reg).insert(type_id, registry.into_raw()); + } + Ok(()) + } + + /// Create a Lua userdata "proxy" object from a custom userdata type. + /// + /// Proxy object is an empty userdata object that has `T` metatable attached. + /// The main purpose of this object is to provide access to static fields and functions + /// without creating an instance of type `T`. + /// + /// You can get or set uservalues on this object but you cannot borrow any Rust type. + /// + /// # Examples + /// + /// ``` + /// # use mlua::{Lua, Result, UserData, UserDataFields, UserDataMethods}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// struct MyUserData(i32); + /// + /// impl UserData for MyUserData { + /// fn add_fields>(fields: &mut F) { + /// fields.add_field_method_get("val", |_, this| Ok(this.0)); + /// } + /// + /// fn add_methods>(methods: &mut M) { + /// methods.add_function("new", |_, value: i32| Ok(MyUserData(value))); + /// } + /// } + /// + /// lua.globals().set("MyUserData", lua.create_proxy::()?)?; + /// + /// lua.load("assert(MyUserData.new(321).val == 321)").exec()?; + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub fn create_proxy(&self) -> Result + where + T: UserData + 'static, + { + let ud = UserDataProxy::(PhantomData); + unsafe { self.lock().make_userdata(UserDataStorage::new(ud)) } + } + + /// Gets the metatable of a Lua built-in (primitive) type. + /// + /// The metatable is shared by all values of the given type. + /// + /// See [`Lua::set_type_metatable`] for examples. + #[allow(private_bounds)] + pub fn type_metatable(&self) -> Option
{ + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 2); + + if lua.push_primitive_type::() && ffi::lua_getmetatable(state, -1) != 0 { + return Some(Table(lua.pop_ref())); + } + } + None + } + + /// Sets the metatable for a Lua built-in (primitive) type. + /// + /// The metatable will be shared by all values of the given type. + /// + /// # Examples + /// + /// Change metatable for Lua boolean type: + /// + /// ``` + /// # use mlua::{Lua, Result, Function}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let mt = lua.create_table()?; + /// mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { "2" } else { "0" }))?)?; + /// lua.set_type_metatable::(Some(mt)); + /// lua.load("assert(tostring(true) == '2')").exec()?; + /// # Ok(()) + /// # } + /// ``` + #[allow(private_bounds)] + pub fn set_type_metatable(&self, metatable: Option
) { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 2); + + if lua.push_primitive_type::() { + match metatable { + Some(metatable) => lua.push_ref(&metatable.0), + None => ffi::lua_pushnil(state), + } + ffi::lua_setmetatable(state, -2); + } + } + } + + /// Returns a handle to the global environment. + pub fn globals(&self) -> Table { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 1); + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_pushvalue(state, ffi::LUA_GLOBALSINDEX); + Table(lua.pop_ref()) + } + } + + /// Sets the global environment. + /// + /// This will replace the current global environment with the provided `globals` table. + /// + /// For Lua 5.2+ the globals table is stored in the registry and shared between all threads. + /// For Lua 5.1 and Luau the globals table is stored in each thread. + /// + /// Please note that any existing Lua functions have cached global environment and will not + /// see the changes made by this method. + /// To update the environment for existing Lua functions, use [`Function::set_environment`]. + pub fn set_globals(&self, globals: Table) -> Result<()> { + let lua = self.lock(); + let state = lua.state(); + unsafe { + #[cfg(feature = "luau")] + if (*lua.extra.get()).sandboxed { + return Err(Error::runtime("cannot change globals in a sandboxed Lua state")); + } + + let _sg = StackGuard::new(state); + check_stack(state, 1)?; + + lua.push_ref(&globals.0); + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_rawseti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_replace(state, ffi::LUA_GLOBALSINDEX); + } + + Ok(()) + } + + /// Returns a handle to the active `Thread`. + /// + /// For calls to `Lua` this will be the main Lua thread, for parameters given to a callback, + /// this will be whatever Lua thread called the callback. + pub fn current_thread(&self) -> Thread { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 1); + ffi::lua_pushthread(state); + Thread(lua.pop_ref(), state) + } + } + + /// Calls the given function with a [`Scope`] parameter, giving the function the ability to + /// create userdata and callbacks from Rust types that are `!Send` or non-`'static`. + /// + /// The lifetime of any function or userdata created through [`Scope`] lasts only until the + /// completion of this method call, on completion all such created values are automatically + /// dropped and Lua references to them are invalidated. If a script accesses a value created + /// through [`Scope`] outside of this method, a Lua error will result. Since we can ensure the + /// lifetime of values created through [`Scope`], and we know that [`Lua`] cannot be sent to + /// another thread while [`Scope`] is live, it is safe to allow `!Send` data types and whose + /// lifetimes only outlive the scope lifetime. + pub fn scope<'env, R>( + &self, + f: impl for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> Result, + ) -> Result { + f(&Scope::new(self.lock_arc())) + } + + /// Attempts to coerce a Lua value into a String in a manner consistent with Lua's internal + /// behavior. + /// + /// To succeed, the value must be a string (in which case this is a no-op), an integer, or a + /// number. + pub fn coerce_string(&self, v: Value) -> Result> { + Ok(match v { + Value::String(s) => Some(s), + v => unsafe { + let lua = self.lock(); + let state = lua.state(); + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + lua.push_value(&v)?; + let res = if lua.unlikely_memory_error() { + ffi::lua_tolstring(state, -1, ptr::null_mut()) + } else { + protect_lua!(state, 1, 1, |state| { + ffi::lua_tolstring(state, -1, ptr::null_mut()) + })? + }; + if !res.is_null() { + Some(LuaString(lua.pop_ref())) + } else { + None + } + }, + }) + } + + /// Attempts to coerce a Lua value into an integer in a manner consistent with Lua's internal + /// behavior. + /// + /// To succeed, the value must be an integer, a floating point number that has an exact + /// representation as an integer, or a string that can be converted to an integer. Refer to the + /// Lua manual for details. + pub fn coerce_integer(&self, v: Value) -> Result> { + Ok(match v { + Value::Integer(i) => Some(i), + v => unsafe { + let lua = self.lock(); + let state = lua.state(); + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + lua.push_value(&v)?; + let mut isint = 0; + let i = ffi::lua_tointegerx(state, -1, &mut isint); + if isint == 0 { None } else { Some(i) } + }, + }) + } + + /// Attempts to coerce a Lua value into a Number in a manner consistent with Lua's internal + /// behavior. + /// + /// To succeed, the value must be a number or a string that can be converted to a number. Refer + /// to the Lua manual for details. + pub fn coerce_number(&self, v: Value) -> Result> { + Ok(match v { + Value::Number(n) => Some(n), + v => unsafe { + let lua = self.lock(); + let state = lua.state(); + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + lua.push_value(&v)?; + let mut isnum = 0; + let n = ffi::lua_tonumberx(state, -1, &mut isnum); + if isnum == 0 { None } else { Some(n) } + }, + }) + } + + /// Converts a value that implements [`IntoLua`] into a [`Value`] instance. + #[inline] + pub fn pack(&self, t: impl IntoLua) -> Result { + t.into_lua(self) + } + + /// Converts a [`Value`] instance into a value that implements [`FromLua`]. + #[inline] + pub fn unpack(&self, value: Value) -> Result { + T::from_lua(value, self) + } + + /// Converts a value that implements [`IntoLua`] into a [`FromLua`] variant. + #[inline] + pub fn convert(&self, value: impl IntoLua) -> Result { + U::from_lua(value.into_lua(self)?, self) + } + + /// Converts a value that implements [`IntoLuaMulti`] into a [`MultiValue`] instance. + #[inline] + pub fn pack_multi(&self, t: impl IntoLuaMulti) -> Result { + t.into_lua_multi(self) + } + + /// Converts a [`MultiValue`] instance into a value that implements [`FromLuaMulti`]. + #[inline] + pub fn unpack_multi(&self, value: MultiValue) -> Result { + T::from_lua_multi(value, self) + } + + /// Set a value in the Lua registry based on a string key. + /// + /// This value will be available to Rust from all Lua instances which share the same main + /// state. + pub fn set_named_registry_value(&self, key: &str, t: impl IntoLua) -> Result<()> { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 5)?; + + lua.push(t)?; + rawset_field(state, ffi::LUA_REGISTRYINDEX, key) + } + } + + /// Get a value from the Lua registry based on a string key. + /// + /// Any Lua instance which shares the underlying main state may call this method to + /// get a value previously set by [`Lua::set_named_registry_value`]. + pub fn named_registry_value(&self, key: &str) -> Result + where + T: FromLua, + { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + let protect = !lua.unlikely_memory_error(); + push_string(state, key.as_bytes(), protect)?; + ffi::lua_rawget(state, ffi::LUA_REGISTRYINDEX); + + T::from_stack(-1, &lua) + } + } + + /// Removes a named value in the Lua registry. + /// + /// Equivalent to calling [`Lua::set_named_registry_value`] with a value of [`Nil`]. + #[inline] + pub fn unset_named_registry_value(&self, key: &str) -> Result<()> { + self.set_named_registry_value(key, Nil) + } + + /// Place a value in the Lua registry with an auto-generated key. + /// + /// This value will be available to Rust from all Lua instances which share the same main + /// state. + /// + /// Be warned, garbage collection of values held inside the registry is not automatic, see + /// [`RegistryKey`] for more details. + /// However, dropped [`RegistryKey`]s automatically reused to store new values. + pub fn create_registry_value(&self, t: impl IntoLua) -> Result { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + lua.push(t)?; + + let unref_list = (*lua.extra.get()).registry_unref_list.clone(); + + // Check if the value is nil (no need to store it in the registry) + if ffi::lua_isnil(state, -1) != 0 { + return Ok(RegistryKey::new(ffi::LUA_REFNIL, unref_list)); + } + + // Try to reuse previously allocated slot + let free_registry_id = unref_list.lock().as_mut().and_then(|x| x.pop()); + if let Some(registry_id) = free_registry_id { + // It must be safe to replace the value without triggering memory error + ffi::lua_rawseti(state, ffi::LUA_REGISTRYINDEX, registry_id as Integer); + return Ok(RegistryKey::new(registry_id, unref_list)); + } + + // Allocate a new RegistryKey slot + let registry_id = if lua.unlikely_memory_error() { + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) + } else { + protect_lua!(state, 1, 0, |state| { + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) + })? + }; + Ok(RegistryKey::new(registry_id, unref_list)) + } + } + + /// Get a value from the Lua registry by its [`RegistryKey`] + /// + /// Any Lua instance which shares the underlying main state may call this method to get a value + /// previously placed by [`Lua::create_registry_value`]. + pub fn registry_value(&self, key: &RegistryKey) -> Result { + let lua = self.lock(); + if !lua.owns_registry_value(key) { + return Err(Error::MismatchedRegistryKey); + } + + let state = lua.state(); + match key.id() { + ffi::LUA_REFNIL => T::from_lua(Value::Nil, self), + registry_id => unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 1)?; + + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, registry_id as Integer); + T::from_stack(-1, &lua) + }, + } + } + + /// Removes a value from the Lua registry. + /// + /// You may call this function to manually remove a value placed in the registry with + /// [`Lua::create_registry_value`]. In addition to manual [`RegistryKey`] removal, you can also + /// call [`Lua::expire_registry_values`] to automatically remove values from the registry + /// whose [`RegistryKey`]s have been dropped. + pub fn remove_registry_value(&self, key: RegistryKey) -> Result<()> { + let lua = self.lock(); + if !lua.owns_registry_value(&key) { + return Err(Error::MismatchedRegistryKey); + } + + unsafe { ffi::luaL_unref(lua.state(), ffi::LUA_REGISTRYINDEX, key.take()) }; + Ok(()) + } + + /// Replaces a value in the Lua registry by its [`RegistryKey`]. + /// + /// An identifier used in [`RegistryKey`] may possibly be changed to a new value. + /// + /// See [`Lua::create_registry_value`] for more details. + pub fn replace_registry_value(&self, key: &mut RegistryKey, t: impl IntoLua) -> Result<()> { + let lua = self.lock(); + if !lua.owns_registry_value(key) { + return Err(Error::MismatchedRegistryKey); + } + + let t = t.into_lua(self)?; + + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 2)?; + + match (t, key.id()) { + (Value::Nil, ffi::LUA_REFNIL) => { + // Do nothing, no need to replace nil with nil + } + (Value::Nil, registry_id) => { + // Remove the value + ffi::luaL_unref(state, ffi::LUA_REGISTRYINDEX, registry_id); + key.set_id(ffi::LUA_REFNIL); + } + (value, ffi::LUA_REFNIL) => { + // Allocate a new `RegistryKey` + let new_key = self.create_registry_value(value)?; + key.set_id(new_key.take()); + } + (value, registry_id) => { + // It must be safe to replace the value without triggering memory error + lua.push_value(&value)?; + ffi::lua_rawseti(state, ffi::LUA_REGISTRYINDEX, registry_id as Integer); + } + } + } + Ok(()) + } + + /// Returns true if the given [`RegistryKey`] was created by a Lua which shares the + /// underlying main state with this Lua instance. + /// + /// Other than this, methods that accept a [`RegistryKey`] will return + /// [`Error::MismatchedRegistryKey`] if passed a [`RegistryKey`] that was not created with a + /// matching [`Lua`] state. + #[inline] + pub fn owns_registry_value(&self, key: &RegistryKey) -> bool { + self.lock().owns_registry_value(key) + } + + /// Remove any registry values whose [`RegistryKey`]s have all been dropped. + /// + /// Unlike normal handle values, [`RegistryKey`]s do not automatically remove themselves on + /// Drop, but you can call this method to remove any unreachable registry values not + /// manually removed by [`Lua::remove_registry_value`]. + pub fn expire_registry_values(&self) { + let lua = self.lock(); + let state = lua.state(); + unsafe { + let mut unref_list = (*lua.extra.get()).registry_unref_list.lock(); + let unref_list = unref_list.replace(Vec::new()); + for id in mlua_expect!(unref_list, "unref list is not set") { + ffi::luaL_unref(state, ffi::LUA_REGISTRYINDEX, id); + } + } + } + + /// Sets or replaces an application data object of type `T`. + /// + /// Application data could be accessed at any time by using [`Lua::app_data_ref`] or + /// [`Lua::app_data_mut`] methods where `T` is the data type. + /// + /// # Panics + /// + /// Panics if the app data container is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use mlua::{Lua, Result}; + /// + /// fn hello(lua: &Lua, _: ()) -> Result<()> { + /// let mut s = lua.app_data_mut::<&str>().unwrap(); + /// assert_eq!(*s, "hello"); + /// *s = "world"; + /// Ok(()) + /// } + /// + /// fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.set_app_data("hello"); + /// lua.create_function(hello)?.call::<()>(())?; + /// let s = lua.app_data_ref::<&str>().unwrap(); + /// assert_eq!(*s, "world"); + /// Ok(()) + /// } + /// ``` + #[track_caller] + pub fn set_app_data(&self, data: T) -> Option { + let lua = self.lock(); + let extra = unsafe { &*lua.extra.get() }; + extra.app_data.insert(data) + } + + /// Tries to set or replace an application data object of type `T`. + /// + /// Returns: + /// - `Ok(Some(old_data))` if the data object of type `T` was successfully replaced. + /// - `Ok(None)` if the data object of type `T` was successfully inserted. + /// - `Err(data)` if the data object of type `T` was not inserted because the container is + /// currently borrowed. + /// + /// See [`Lua::set_app_data`] for examples. + pub fn try_set_app_data(&self, data: T) -> StdResult, T> { + let lua = self.lock(); + let extra = unsafe { &*lua.extra.get() }; + extra.app_data.try_insert(data) + } + + /// Gets a reference to an application data object stored by [`Lua::set_app_data`] of type + /// `T`. + /// + /// # Panics + /// + /// Panics if the data object of type `T` is currently mutably borrowed. Multiple immutable + /// reads can be taken out at the same time. + #[track_caller] + pub fn app_data_ref(&self) -> Option> { + let guard = self.lock_arc(); + let extra = unsafe { &*guard.extra.get() }; + extra.app_data.borrow(Some(guard)) + } + + /// Tries to get a reference to an application data object stored by [`Lua::set_app_data`] of + /// type `T`. + pub fn try_app_data_ref(&self) -> StdResult>, BorrowError> { + let guard = self.lock_arc(); + let extra = unsafe { &*guard.extra.get() }; + extra.app_data.try_borrow(Some(guard)) + } + + /// Gets a mutable reference to an application data object stored by [`Lua::set_app_data`] of + /// type `T`. + /// + /// # Panics + /// + /// Panics if the data object of type `T` is currently borrowed. + #[track_caller] + pub fn app_data_mut(&self) -> Option> { + let guard = self.lock_arc(); + let extra = unsafe { &*guard.extra.get() }; + extra.app_data.borrow_mut(Some(guard)) + } + + /// Tries to get a mutable reference to an application data object stored by + /// [`Lua::set_app_data`] of type `T`. + pub fn try_app_data_mut(&self) -> StdResult>, BorrowMutError> { + let guard = self.lock_arc(); + let extra = unsafe { &*guard.extra.get() }; + extra.app_data.try_borrow_mut(Some(guard)) + } + + /// Removes an application data of type `T`. + /// + /// # Panics + /// + /// Panics if the app data container is currently borrowed. + #[track_caller] + pub fn remove_app_data(&self) -> Option { + let lua = self.lock(); + let extra = unsafe { &*lua.extra.get() }; + extra.app_data.remove() + } + + /// Returns an internal `Poll::Pending` constant used for executing async callbacks. + /// + /// Every time when [`Future`] is Pending, Lua corotine is suspended with this constant. + #[cfg(feature = "async")] + #[doc(hidden)] + #[inline(always)] + pub fn poll_pending() -> LightUserData { + static ASYNC_POLL_PENDING: u8 = 0; + LightUserData(&ASYNC_POLL_PENDING as *const u8 as *mut std::os::raw::c_void) + } + + #[cfg(feature = "async")] + #[inline(always)] + pub(crate) fn poll_terminate() -> LightUserData { + static ASYNC_POLL_TERMINATE: u8 = 0; + LightUserData(&ASYNC_POLL_TERMINATE as *const u8 as *mut std::os::raw::c_void) + } + + #[cfg(feature = "async")] + #[inline(always)] + pub(crate) fn poll_yield() -> LightUserData { + static ASYNC_POLL_YIELD: u8 = 0; + LightUserData(&ASYNC_POLL_YIELD as *const u8 as *mut std::os::raw::c_void) + } + + /// Suspends the current async function, returning the provided arguments to caller. + /// + /// This function is similar to [`coroutine.yield`] but allow yielding Rust functions + /// and passing values to the caller. + /// Please note that you cannot cross [`Thread`] boundaries (e.g. calling `yield_with` on one + /// thread and resuming on another). + /// + /// # Examples + /// + /// Async iterator: + /// + /// ``` + /// # use mlua::{Lua, Result}; + /// # + /// async fn generator(lua: Lua, _: ()) -> Result<()> { + /// for i in 0..10 { + /// lua.yield_with::<()>(i).await?; + /// } + /// Ok(()) + /// } + /// + /// fn main() -> Result<()> { + /// let lua = Lua::new(); + /// lua.globals().set("generator", lua.create_async_function(generator)?)?; + /// + /// lua.load(r#" + /// local n = 0 + /// for i in coroutine.wrap(generator) do + /// n = n + i + /// end + /// assert(n == 45) + /// "#) + /// .exec() + /// } + /// ``` + /// + /// Exchange values on yield: + /// + /// ``` + /// # use mlua::{Lua, Result, Value}; + /// # + /// async fn pingpong(lua: Lua, mut val: i32) -> Result<()> { + /// loop { + /// val = lua.yield_with::(val).await? + 1; + /// } + /// Ok(()) + /// } + /// + /// # fn main() -> Result<()> { + /// let lua = Lua::new(); + /// + /// let co = lua.create_thread(lua.create_async_function(pingpong)?)?; + /// assert_eq!(co.resume::(1)?, 1); + /// assert_eq!(co.resume::(2)?, 3); + /// assert_eq!(co.resume::(3)?, 4); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// [`coroutine.yield`]: https://www.lua.org/manual/5.4/manual.html#pdf-coroutine.yield + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub async fn yield_with(&self, args: impl IntoLuaMulti) -> Result { + let mut args = Some(args.into_lua_multi(self)?); + future::poll_fn(move |_cx| match args.take() { + Some(args) => unsafe { + let lua = self.lock(); + lua.push(Self::poll_yield())?; // yield marker + if args.len() <= 1 { + lua.push(args.front())?; + } else { + lua.push(lua.create_sequence_from(&args)?)?; + } + lua.push(args.len())?; + Poll::Pending + }, + None => unsafe { + let lua = self.lock(); + let state = lua.state(); + let top = ffi::lua_gettop(state); + if top == 0 || ffi::lua_type(state, 1) != ffi::LUA_TUSERDATA { + // This must be impossible scenario if used correctly + return Poll::Ready(R::from_stack_multi(0, &lua)); + } + let _sg = StackGuard::with_top(state, 1); + Poll::Ready(R::from_stack_multi(top - 1, &lua)) + }, + }) + .await + } + + /// Returns a weak reference to the Lua instance. + /// + /// This is useful for creating a reference to the Lua instance that does not prevent it from + /// being deallocated. + #[inline(always)] + pub fn weak(&self) -> WeakLua { + WeakLua(XRc::downgrade(&self.raw)) + } + + #[cfg(not(feature = "luau"))] + fn disable_c_modules(&self) -> Result<()> { + let package: Table = self.globals().get("package")?; + + package.set( + "loadlib", + self.create_function(|_, ()| -> Result<()> { + Err(Error::SafetyError( + "package.loadlib is disabled in safe mode".to_string(), + )) + })?, + )?; + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + let searchers: Table = package.get("searchers")?; + #[cfg(any(feature = "lua51", feature = "luajit"))] + let searchers: Table = package.get("loaders")?; + + let loader = self.create_function(|_, ()| Ok("\n\tcan't load C modules in safe mode"))?; + + // The third and fourth searchers looks for a loader as a C library + searchers.raw_set(3, loader)?; + if searchers.raw_len() >= 4 { + searchers.raw_remove(4)?; + } + + Ok(()) + } + + #[inline(always)] + pub(crate) fn lock(&self) -> ReentrantMutexGuard<'_, RawLua> { + let rawlua = self.raw.lock(); + #[cfg(feature = "luau")] + if unsafe { (*rawlua.extra.get()).running_gc } { + panic!("Luau VM is suspended while GC is running"); + } + rawlua + } + + #[inline(always)] + pub(crate) fn lock_arc(&self) -> LuaGuard { + LuaGuard(self.raw.lock_arc()) + } + + /// Returns a handle to the unprotected Lua state without any synchronization. + /// + /// This is useful where we know that the lock is already held by the caller. + #[cfg(feature = "async")] + #[inline(always)] + pub(crate) unsafe fn raw_lua(&self) -> &RawLua { + &*self.raw.data_ptr() + } +} + +impl WeakLua { + #[track_caller] + #[inline(always)] + pub(crate) fn lock(&self) -> LuaGuard { + let guard = LuaGuard::new(self.0.upgrade().expect("Lua instance is destroyed")); + #[cfg(feature = "luau")] + if unsafe { (*guard.extra.get()).running_gc } { + panic!("Luau VM is suspended while GC is running"); + } + guard + } + + #[inline(always)] + pub(crate) fn try_lock(&self) -> Option { + Some(LuaGuard::new(self.0.upgrade()?)) + } + + /// Upgrades the weak Lua reference to a strong reference. + /// + /// # Panics + /// + /// Panics if the Lua instance is destroyed. + #[track_caller] + #[inline(always)] + pub fn upgrade(&self) -> Lua { + Lua { + raw: self.0.upgrade().expect("Lua instance is destroyed"), + collect_garbage: false, + } + } + + /// Tries to upgrade the weak Lua reference to a strong reference. + /// + /// Returns `None` if the Lua instance is destroyed. + #[inline(always)] + pub fn try_upgrade(&self) -> Option { + Some(Lua { + raw: self.0.upgrade()?, + collect_garbage: false, + }) + } +} + +impl PartialEq for WeakLua { + fn eq(&self, other: &Self) -> bool { + XWeak::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for WeakLua {} + +impl LuaGuard { + #[cfg(feature = "send")] + pub(crate) fn new(handle: XRc>) -> Self { + LuaGuard(handle.lock_arc()) + } + + #[cfg(not(feature = "send"))] + pub(crate) fn new(handle: XRc>) -> Self { + LuaGuard(handle.into_lock_arc()) + } +} + +impl Deref for LuaGuard { + type Target = RawLua; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub(crate) mod extra; +mod raw; +pub(crate) mod util; + +#[cfg(test)] +mod assertions { + use super::*; + + // Lua has lots of interior mutability, should not be RefUnwindSafe + static_assertions::assert_not_impl_any!(Lua: std::panic::RefUnwindSafe); + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(Lua: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(Lua: Send, Sync); +} diff --git a/src/state/extra.rs b/src/state/extra.rs new file mode 100644 index 00000000..d761c9cb --- /dev/null +++ b/src/state/extra.rs @@ -0,0 +1,294 @@ +use std::any::TypeId; +use std::cell::UnsafeCell; +use std::mem::MaybeUninit; +use std::os::raw::{c_int, c_void}; +use std::ptr; +use std::rc::Rc; +use std::sync::Arc; + +use parking_lot::Mutex; +use rustc_hash::FxHashMap; + +use crate::error::Result; +use crate::state::RawLua; +use crate::stdlib::StdLib; +use crate::types::{AppData, ReentrantMutex, XRc}; +use crate::userdata::RawUserDataRegistry; +use crate::util::{TypeKey, WrappedFailure, get_internal_metatable, push_internal_userdata}; + +#[cfg(any(feature = "luau", doc))] +use crate::chunk::Compiler; + +#[cfg(feature = "async")] +use {futures_util::task::noop_waker_ref, std::ptr::NonNull, std::task::Waker}; + +use super::{Lua, WeakLua}; + +// Unique key to store `ExtraData` in the registry +static EXTRA_REGISTRY_KEY: u8 = 0; + +const WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY: usize = 64; +const REF_STACK_RESERVE: c_int = 3; + +/// Data associated with the Lua state. +pub(crate) struct ExtraData { + pub(super) lua: MaybeUninit, + pub(super) weak: MaybeUninit, + pub(super) owned: bool, + + pub(super) pending_userdata_reg: FxHashMap, + pub(super) registered_userdata_t: FxHashMap, + pub(super) registered_userdata_mt: FxHashMap<*const c_void, Option>, + pub(super) last_checked_userdata_mt: (*const c_void, Option), + + // When Lua instance dropped, setting `None` would prevent collecting `RegistryKey`s + pub(super) registry_unref_list: Arc>>>, + + // Containers to store arbitrary data (extensions) + pub(super) app_data: AppData, + pub(super) app_data_priv: AppData, + + pub(super) safe: bool, + pub(super) libs: StdLib, + // Used in module mode + pub(super) skip_memory_check: bool, + + // Auxiliary thread to store references + pub(super) ref_thread: *mut ffi::lua_State, + pub(super) ref_stack_size: c_int, + pub(super) ref_stack_top: c_int, + pub(super) ref_free: Vec, + + // Pool of `WrappedFailure` enums in the ref thread (as userdata) + pub(super) wrapped_failure_pool: Vec, + pub(super) wrapped_failure_top: usize, + // Pool of `Thread`s (coroutines) for async execution + #[cfg(feature = "async")] + pub(super) thread_pool: Vec, + + // Address of `WrappedFailure` metatable + pub(super) wrapped_failure_mt_ptr: *const c_void, + + // Waker for polling futures + #[cfg(feature = "async")] + pub(super) waker: NonNull, + + #[cfg(not(feature = "luau"))] + pub(super) hook_callback: Option, + #[cfg(not(feature = "luau"))] + pub(super) hook_triggers: crate::debug::HookTriggers, + #[cfg(any(feature = "lua55", feature = "lua54"))] + pub(super) warn_callback: Option, + #[cfg(feature = "luau")] + pub(super) interrupt_callback: Option, + #[cfg(feature = "luau")] + pub(super) thread_creation_callback: Option, + #[cfg(feature = "luau")] + pub(super) thread_collection_callback: Option, + + #[cfg(feature = "luau")] + pub(crate) running_gc: bool, + #[cfg(feature = "luau")] + pub(crate) sandboxed: bool, + #[cfg(feature = "luau")] + pub(super) compiler: Option, + #[cfg(feature = "luau-jit")] + pub(super) enable_jit: bool, + #[cfg(feature = "luau")] + pub(crate) mem_categories: Vec, +} + +impl Drop for ExtraData { + fn drop(&mut self) { + unsafe { + if !self.owned { + self.lua.assume_init_drop(); + } + + self.weak.assume_init_drop(); + } + *self.registry_unref_list.lock() = None; + } +} + +static EXTRA_TYPE_KEY: u8 = 0; + +impl TypeKey for XRc> { + #[inline(always)] + fn type_key() -> *const c_void { + &EXTRA_TYPE_KEY as *const u8 as *const c_void + } +} + +impl ExtraData { + // Index of `error_traceback` function in auxiliary thread stack + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + pub(super) const ERROR_TRACEBACK_IDX: c_int = 1; + + pub(super) unsafe fn init(state: *mut ffi::lua_State, owned: bool) -> XRc> { + // Create ref stack thread and place it in the registry to prevent it + // from being garbage collected. + let ref_thread = mlua_expect!( + protect_lua!(state, 0, 0, |state| { + let thread = ffi::lua_newthread(state); + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX); + thread + }), + "Error while creating ref thread", + ); + + let wrapped_failure_mt_ptr = { + get_internal_metatable::(state); + let ptr = ffi::lua_topointer(state, -1); + ffi::lua_pop(state, 1); + ptr + }; + + // Store `error_traceback` function on the ref stack + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + { + ffi::lua_pushcfunction(ref_thread, crate::util::error_traceback); + assert_eq!(ffi::lua_gettop(ref_thread), Self::ERROR_TRACEBACK_IDX); + } + + #[allow(clippy::arc_with_non_send_sync)] + let extra = XRc::new(UnsafeCell::new(ExtraData { + lua: MaybeUninit::uninit(), + weak: MaybeUninit::uninit(), + owned, + pending_userdata_reg: FxHashMap::default(), + registered_userdata_t: FxHashMap::default(), + registered_userdata_mt: FxHashMap::default(), + last_checked_userdata_mt: (ptr::null(), None), + registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))), + app_data: AppData::default(), + app_data_priv: AppData::default(), + safe: false, + libs: StdLib::NONE, + skip_memory_check: false, + ref_thread, + // We need some reserved stack space to move values in and out of the ref stack. + ref_stack_size: ffi::LUA_MINSTACK - REF_STACK_RESERVE, + ref_stack_top: ffi::lua_gettop(ref_thread), + ref_free: Vec::new(), + wrapped_failure_pool: Vec::with_capacity(WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY), + wrapped_failure_top: 0, + #[cfg(feature = "async")] + thread_pool: Vec::new(), + wrapped_failure_mt_ptr, + #[cfg(feature = "async")] + waker: NonNull::from(noop_waker_ref()), + #[cfg(not(feature = "luau"))] + hook_callback: None, + #[cfg(not(feature = "luau"))] + hook_triggers: Default::default(), + #[cfg(any(feature = "lua55", feature = "lua54"))] + warn_callback: None, + #[cfg(feature = "luau")] + interrupt_callback: None, + #[cfg(feature = "luau")] + thread_creation_callback: None, + #[cfg(feature = "luau")] + thread_collection_callback: None, + #[cfg(feature = "luau")] + sandboxed: false, + #[cfg(feature = "luau")] + compiler: None, + #[cfg(feature = "luau-jit")] + enable_jit: true, + #[cfg(feature = "luau")] + running_gc: false, + #[cfg(feature = "luau")] + mem_categories: vec![std::ffi::CString::new("main").unwrap()], + })); + + // Store it in the registry + mlua_expect!(Self::store(&extra, state), "Error while storing extra data"); + + extra + } + + pub(super) unsafe fn set_lua(&mut self, raw: &XRc>) { + self.lua.write(Lua { + raw: XRc::clone(raw), + collect_garbage: false, + }); + self.weak.write(WeakLua(XRc::downgrade(raw))); + } + + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + // In the main app we can use `lua_callbacks` to access ExtraData + return (*ffi::lua_callbacks(state)).userdata as *mut _; + } + + let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { + // `ExtraData` can be null only when Lua state is foreign. + // This case in used in `Lua::try_from_ptr()`. + ffi::lua_pop(state, 1); + return ptr::null_mut(); + } + let extra_ptr = ffi::lua_touserdata(state, -1) as *mut Rc>; + ffi::lua_pop(state, 1); + (*extra_ptr).get() + } + + unsafe fn store(extra: &XRc>, state: *mut ffi::lua_State) -> Result<()> { + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + (*ffi::lua_callbacks(state)).userdata = extra.get() as *mut _; + return Ok(()); + } + + push_internal_userdata(state, XRc::clone(extra), true)?; + protect_lua!(state, 1, 0, fn(state) { + let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key); + }) + } + + #[inline(always)] + pub(super) unsafe fn lua(&self) -> &Lua { + self.lua.assume_init_ref() + } + + #[inline(always)] + pub(crate) unsafe fn raw_lua(&self) -> &RawLua { + &*self.lua.assume_init_ref().raw.data_ptr() + } + + #[inline(always)] + pub(super) unsafe fn weak(&self) -> &WeakLua { + self.weak.assume_init_ref() + } + + /// Pops a reference from top of the auxiliary stack and move it to a first free slot. + pub(super) unsafe fn ref_stack_pop(&mut self) -> c_int { + if let Some(free) = self.ref_free.pop() { + ffi::lua_replace(self.ref_thread, free); + return free; + } + + // Try to grow max stack size + if self.ref_stack_top >= self.ref_stack_size { + let mut inc = self.ref_stack_size; // Try to double stack size + while inc > 0 && ffi::lua_checkstack(self.ref_thread, inc + REF_STACK_RESERVE) == 0 { + inc /= 2; + } + if inc == 0 { + // Pop item on top of the stack to avoid stack leaking and successfully run destructors + // during unwinding. + ffi::lua_pop(self.ref_thread, 1); + let top = self.ref_stack_top; + // It is a user error to create too many references to exhaust the Lua max stack size + // for the ref thread. + panic!("cannot create a Lua reference, out of auxiliary stack space (used {top} slots)"); + } + self.ref_stack_size += inc; + } + self.ref_stack_top += 1; + self.ref_stack_top + } +} diff --git a/src/state/raw.rs b/src/state/raw.rs new file mode 100644 index 00000000..dcc860a3 --- /dev/null +++ b/src/state/raw.rs @@ -0,0 +1,1637 @@ +use std::any::TypeId; +use std::cell::{Cell, UnsafeCell}; +use std::ffi::CStr; +use std::mem; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::resume_unwind; +use std::ptr::{self, NonNull}; +use std::sync::Arc; + +use crate::chunk::ChunkMode; +use crate::error::{Error, Result}; +use crate::function::Function; +use crate::memory::{ALLOCATOR, MemoryState}; +use crate::state::util::callback_error_ext; +use crate::stdlib::StdLib; +use crate::string::LuaString; +use crate::table::Table; +use crate::thread::Thread; +use crate::traits::{FromLua, IntoLua}; +use crate::types::{ + AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, + LuaType, MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, +}; +use crate::userdata::{ + AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, + init_userdata_metatable, +}; +use crate::util::{ + StackGuard, WrappedFailure, assert_stack, check_stack, get_destructed_userdata_metatable, + get_internal_userdata, get_main_state, get_metatable_ptr, get_userdata, init_error_registry, + init_internal_metatable, pop_error, push_internal_userdata, push_string, push_table, push_userdata, + rawset_field, safe_pcall, safe_xpcall, short_type_name, +}; +use crate::value::{Nil, Value}; + +use super::extra::ExtraData; +use super::{Lua, LuaOptions, WeakLua}; + +#[cfg(not(feature = "luau"))] +use crate::{ + debug::Debug, + types::{HookCallback, HookKind, VmState}, +}; + +#[cfg(feature = "async")] +use { + crate::multi::MultiValue, + crate::traits::FromLuaMulti, + crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}, + std::task::{Context, Poll, Waker}, +}; + +/// An internal Lua struct which holds a raw Lua state. +#[doc(hidden)] +pub struct RawLua { + // The state is dynamic and depends on context + pub(super) state: Cell<*mut ffi::lua_State>, + pub(super) main_state: Option>, + pub(super) extra: XRc>, + owned: bool, +} + +impl Drop for RawLua { + fn drop(&mut self) { + unsafe { + if !self.owned { + return; + } + + let mem_state = MemoryState::get(self.main_state()); + + #[cfg(feature = "luau")] + { + // Reset any callbacks + (*ffi::lua_callbacks(self.main_state())).interrupt = None; + (*ffi::lua_callbacks(self.main_state())).userthread = None; + } + + ffi::lua_close(self.main_state()); + + // Deallocate `MemoryState` + if !mem_state.is_null() { + drop(Box::from_raw(mem_state)); + } + } + } +} + +#[cfg(feature = "send")] +unsafe impl Send for RawLua {} + +impl RawLua { + #[inline(always)] + pub(crate) fn lua(&self) -> &Lua { + unsafe { (*self.extra.get()).lua() } + } + + #[inline(always)] + pub(crate) fn weak(&self) -> &WeakLua { + unsafe { (*self.extra.get()).weak() } + } + + /// Returns a pointer to the current Lua state. + /// + /// The pointer refers to the active Lua coroutine and depends on the context. + #[inline(always)] + pub fn state(&self) -> *mut ffi::lua_State { + self.state.get() + } + + #[inline(always)] + pub(crate) fn main_state(&self) -> *mut ffi::lua_State { + self.main_state + .map(|state| state.as_ptr()) + .unwrap_or_else(|| self.state()) + } + + #[inline(always)] + pub(crate) fn ref_thread(&self) -> *mut ffi::lua_State { + unsafe { (*self.extra.get()).ref_thread } + } + + pub(super) unsafe fn new(libs: StdLib, options: &LuaOptions) -> XRc> { + let mem_state: *mut MemoryState = Box::into_raw(Box::default()); + #[cfg(feature = "lua55")] + let mut state = { + let seed = ffi::luaL_makeseed(ptr::null_mut()); + ffi::lua_newstate(ALLOCATOR, mem_state as *mut c_void, seed) + }; + #[cfg(not(feature = "lua55"))] + let mut state = ffi::lua_newstate(ALLOCATOR, mem_state as *mut c_void); + // If state is null then switch to Lua internal allocator + if state.is_null() { + drop(Box::from_raw(mem_state)); + state = ffi::luaL_newstate(); + } + assert!(!state.is_null(), "Failed to create a Lua VM"); + + ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1); + ffi::lua_pop(state, 1); + + // Init Luau code generator (jit) + #[cfg(feature = "luau-jit")] + if ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_create(state); + } + + let rawlua = Self::init_from_ptr(state, true); + let extra = rawlua.lock().extra.get(); + + mlua_expect!( + load_std_libs(state, libs), + "Error during loading standard libraries" + ); + (*extra).libs |= libs; + + if !options.catch_rust_panics { + mlua_expect!( + (|| -> Result<()> { + let _sg = StackGuard::new(state); + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS); + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_pushvalue(state, ffi::LUA_GLOBALSINDEX); + + ffi::lua_pushcfunction(state, safe_pcall); + rawset_field(state, -2, "pcall")?; + + ffi::lua_pushcfunction(state, safe_xpcall); + rawset_field(state, -2, "xpcall")?; + + Ok(()) + })(), + "Error during applying option `catch_rust_panics`" + ) + } + + #[cfg(feature = "async")] + if options.thread_pool_size > 0 { + (*extra).thread_pool.reserve_exact(options.thread_pool_size); + } + + rawlua + } + + pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State, owned: bool) -> XRc> { + assert!(!state.is_null(), "Lua state is NULL"); + if let Some(lua) = Self::try_from_ptr(state) { + return lua; + } + + let main_state = get_main_state(state).unwrap_or(state); + let main_state_top = ffi::lua_gettop(main_state); + + mlua_expect!( + (|state| { + init_error_registry(state)?; + + // Create the internal metatables and store them in the registry + // to prevent from being garbage collected. + + init_internal_metatable::>>(state, None)?; + init_internal_metatable::(state, None)?; + init_internal_metatable::(state, None)?; + #[cfg(not(feature = "luau"))] + init_internal_metatable::(state, None)?; + #[cfg(feature = "async")] + { + init_internal_metatable::(state, None)?; + init_internal_metatable::(state, None)?; + init_internal_metatable::(state, None)?; + init_internal_metatable::>(state, None)?; + } + + // Init serde metatables + #[cfg(feature = "serde")] + crate::serde::init_metatables(state)?; + + Ok::<_, Error>(()) + })(main_state), + "Error during Lua initialization", + ); + + // Init ExtraData + let extra = ExtraData::init(main_state, owned); + + // Register `DestructedUserdata` type + get_destructed_userdata_metatable(main_state); + let destructed_mt_ptr = ffi::lua_topointer(main_state, -1); + let destructed_ud_typeid = TypeId::of::(); + (*extra.get()) + .registered_userdata_mt + .insert(destructed_mt_ptr, Some(destructed_ud_typeid)); + ffi::lua_pop(main_state, 1); + + mlua_debug_assert!( + ffi::lua_gettop(main_state) == main_state_top, + "stack leak during creation" + ); + assert_stack(main_state, ffi::LUA_MINSTACK); + + #[allow(clippy::arc_with_non_send_sync)] + let rawlua = XRc::new(ReentrantMutex::new(RawLua { + state: Cell::new(state), + // Make sure that we don't store current state as main state (if it's not available) + main_state: get_main_state(state).and_then(NonNull::new), + extra: XRc::clone(&extra), + owned, + })); + (*extra.get()).set_lua(&rawlua); + if owned { + // If Lua state is managed by us, then make internal `RawLua` reference "weak" + XRc::decrement_strong_count(XRc::as_ptr(&rawlua)); + } else { + // If Lua state is not managed by us, then keep internal `RawLua` reference "strong" + // but `Extra` reference weak (it will be collected from registry at lua_close time) + XRc::decrement_strong_count(XRc::as_ptr(&extra)); + } + + rawlua + } + + unsafe fn try_from_ptr(state: *mut ffi::lua_State) -> Option>> { + match ExtraData::get(state) { + extra if extra.is_null() => None, + extra => Some(XRc::clone(&(*extra).lua().raw)), + } + } + + /// Marks the Lua state as safe. + #[inline(always)] + pub(super) fn mark_safe(&self) { + unsafe { (*self.extra.get()).safe = true }; + } + + /// Loads the specified subset of the standard libraries into an existing Lua state. + /// + /// Use the [`StdLib`] flags to specify the libraries you want to load. + /// + /// [`StdLib`]: crate::StdLib + pub(super) unsafe fn load_std_libs(&self, libs: StdLib) -> Result<()> { + let is_safe = (*self.extra.get()).safe; + + #[cfg(not(feature = "luau"))] + if is_safe && libs.contains(StdLib::DEBUG) { + return Err(Error::SafetyError( + "the unsafe `debug` module can't be loaded in safe mode".to_string(), + )); + } + #[cfg(feature = "luajit")] + if is_safe && libs.contains(StdLib::FFI) { + return Err(Error::SafetyError( + "the unsafe `ffi` module can't be loaded in safe mode".to_string(), + )); + } + + let res = load_std_libs(self.main_state(), libs); + + // If `package` library loaded into a safe lua state then disable C modules + #[cfg(not(feature = "luau"))] + if is_safe { + let curr_libs = (*self.extra.get()).libs; + if (curr_libs ^ (curr_libs | libs)).contains(StdLib::PACKAGE) { + mlua_expect!(self.lua().disable_c_modules(), "Error disabling C modules"); + } + } + #[cfg(feature = "luau")] + let _ = is_safe; + unsafe { (*self.extra.get()).libs |= libs }; + + res + } + + /// Private version of [`Lua::try_set_app_data`] + #[inline] + pub(crate) fn set_priv_app_data(&self, data: T) -> Option { + let extra = unsafe { &*self.extra.get() }; + extra.app_data_priv.insert(data) + } + + /// Private version of [`Lua::app_data_ref`] + #[track_caller] + #[inline] + pub(crate) fn priv_app_data_ref(&self) -> Option> { + let extra = unsafe { &*self.extra.get() }; + extra.app_data_priv.borrow(None) + } + + /// Private version of [`Lua::app_data_mut`] + #[track_caller] + #[inline] + pub(crate) fn priv_app_data_mut(&self) -> Option> { + let extra = unsafe { &*self.extra.get() }; + extra.app_data_priv.borrow_mut(None) + } + + /// See [`Lua::create_registry_value`] + #[inline] + pub(crate) fn owns_registry_value(&self, key: &RegistryKey) -> bool { + let registry_unref_list = unsafe { &(*self.extra.get()).registry_unref_list }; + Arc::ptr_eq(&key.unref_list, registry_unref_list) + } + + pub(crate) fn load_chunk( + &self, + name: Option<&CStr>, + env: Option<&Table>, + mode: Option, + source: &[u8], + ) -> Result { + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + let name = name.map(CStr::as_ptr).unwrap_or(ptr::null()); + let mode = match mode { + Some(ChunkMode::Binary) => cstr!("b"), + Some(ChunkMode::Text) => cstr!("t"), + None => cstr!("bt"), + }; + let status = if self.unlikely_memory_error() { + self.load_chunk_inner(state, name, env, mode, source) + } else { + // Luau and Lua 5.2 can trigger an exception during chunk loading + protect_lua!(state, 0, 1, |state| { + self.load_chunk_inner(state, name, env, mode, source) + })? + }; + match status { + ffi::LUA_OK => Ok(Function(self.pop_ref())), + err => Err(pop_error(state, err)), + } + } + } + + pub(crate) unsafe fn load_chunk_inner( + &self, + state: *mut ffi::lua_State, + name: *const c_char, + env: Option<&Table>, + mode: *const c_char, + source: &[u8], + ) -> c_int { + let status = ffi::luaL_loadbufferenv( + state, + source.as_ptr() as *const c_char, + source.len(), + name, + mode, + match env { + Some(env) => { + self.push_ref(&env.0); + -1 + } + _ => 0, + }, + ); + #[cfg(feature = "luau-jit")] + if status == ffi::LUA_OK { + if (*self.extra.get()).enable_jit && ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_compile(state, -1); + } + } + status + } + + /// Sets a hook for a thread (coroutine). + #[cfg(not(feature = "luau"))] + pub(crate) unsafe fn set_thread_hook( + &self, + thread_state: *mut ffi::lua_State, + hook: HookKind, + ) -> Result<()> { + // Key to store hooks in the registry + const HOOKS_KEY: *const c_char = cstr!("__mlua_hooks"); + + unsafe fn process_status(state: *mut ffi::lua_State, event: c_int, status: VmState) { + match status { + VmState::Continue => {} + VmState::Yield => { + // Only count and line events can yield + if event == ffi::LUA_HOOKCOUNT || event == ffi::LUA_HOOKLINE { + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + if ffi::lua_isyieldable(state) != 0 { + ffi::lua_yield(state, 0); + } + #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] + { + ffi::lua_pushliteral(state, c"attempt to yield from a hook"); + ffi::lua_error(state); + } + } + } + } + } + + unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { + let status = callback_error_ext(state, ptr::null_mut(), false, move |extra, _| { + match (*extra).hook_callback.clone() { + Some(hook_callback) => { + let rawlua = (*extra).raw_lua(); + let debug = Debug::new(rawlua, 0, ar); + hook_callback((*extra).lua(), &debug) + } + None => { + ffi::lua_sethook(state, None, 0, 0); + Ok(VmState::Continue) + } + } + }); + process_status(state, (*ar).event, status); + } + + unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { + let top = ffi::lua_gettop(state); + let mut hook_callback_ptr = ptr::null(); + ffi::luaL_checkstack(state, 3, ptr::null()); + if ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == ffi::LUA_TTABLE { + ffi::lua_pushthread(state); + if ffi::lua_rawget(state, -2) == ffi::LUA_TUSERDATA { + hook_callback_ptr = get_internal_userdata::(state, -1, ptr::null()); + } + } + ffi::lua_settop(state, top); + if hook_callback_ptr.is_null() { + ffi::lua_sethook(state, None, 0, 0); + return; + } + + let status = callback_error_ext(state, ptr::null_mut(), false, |extra, _| { + let rawlua = (*extra).raw_lua(); + let debug = Debug::new(rawlua, 0, ar); + let hook_callback = (*hook_callback_ptr).clone(); + hook_callback((*extra).lua(), &debug) + }); + process_status(state, (*ar).event, status) + } + + let (triggers, callback) = match hook { + HookKind::Global if (*self.extra.get()).hook_callback.is_none() => { + return Ok(()); + } + HookKind::Global => { + let triggers = (*self.extra.get()).hook_triggers; + let (mask, count) = (triggers.mask(), triggers.count()); + ffi::lua_sethook(thread_state, Some(global_hook_proc), mask, count); + return Ok(()); + } + HookKind::Thread(triggers, callback) => (triggers, callback), + }; + + // Hooks for threads stored in the registry (in a weak table) + let state = self.state(); + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + protect_lua!(state, 0, 0, |state| { + if ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == 0 { + // Table just created, initialize it + ffi::lua_pushliteral(state, c"k"); + ffi::lua_setfield(state, -2, cstr!("__mode")); // hooktable.__mode = "k" + ffi::lua_pushvalue(state, -1); + ffi::lua_setmetatable(state, -2); // metatable(hooktable) = hooktable + } + + ffi::lua_pushthread(thread_state); + ffi::lua_xmove(thread_state, state, 1); // key (thread) + let _ = push_internal_userdata(state, callback, false); // value (hook callback) + ffi::lua_rawset(state, -3); // hooktable[thread] = hook callback + })?; + + ffi::lua_sethook(thread_state, Some(hook_proc), triggers.mask(), triggers.count()); + + Ok(()) + } + + /// See [`Lua::create_string`] + pub(crate) unsafe fn create_string(&self, s: &[u8]) -> Result { + let state = self.state(); + if self.unlikely_memory_error() { + push_string(state, s, false)?; + return Ok(LuaString(self.pop_ref())); + } + + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + push_string(state, s, true)?; + Ok(LuaString(self.pop_ref())) + } + + /// Creates an external string, that is, a string that uses memory not managed by Lua. + /// + /// Modifies the input data to add `\0` terminator. + #[cfg(feature = "lua55")] + pub(crate) unsafe fn create_external_string(&self, bytes: Vec) -> Result { + let state = self.state(); + if self.unlikely_memory_error() { + crate::util::push_external_string(state, bytes, false)?; + return Ok(LuaString(self.pop_ref())); + } + + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + crate::util::push_external_string(state, bytes, true)?; + Ok(LuaString(self.pop_ref())) + } + + #[cfg(feature = "luau")] + pub(crate) unsafe fn create_buffer_with_capacity(&self, size: usize) -> Result<(*mut u8, crate::Buffer)> { + let state = self.state(); + if self.unlikely_memory_error() { + let ptr = crate::util::push_buffer(state, size, false)?; + return Ok((ptr, crate::Buffer(self.pop_ref()))); + } + + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + let ptr = crate::util::push_buffer(state, size, true)?; + Ok((ptr, crate::Buffer(self.pop_ref()))) + } + + /// See [`Lua::create_table_with_capacity`] + pub(crate) unsafe fn create_table_with_capacity(&self, narr: usize, nrec: usize) -> Result
{ + let state = self.state(); + if self.unlikely_memory_error() { + push_table(state, narr, nrec, false)?; + return Ok(Table(self.pop_ref())); + } + + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + push_table(state, narr, nrec, true)?; + Ok(Table(self.pop_ref())) + } + + /// See [`Lua::create_table_from`] + pub(crate) unsafe fn create_table_from(&self, iter: I) -> Result
+ where + I: IntoIterator, + K: IntoLua, + V: IntoLua, + { + let state = self.state(); + let _sg = StackGuard::new(state); + check_stack(state, 6)?; + + let iter = iter.into_iter(); + let lower_bound = iter.size_hint().0; + let protect = !self.unlikely_memory_error(); + push_table(state, 0, lower_bound, protect)?; + for (k, v) in iter { + self.push(k)?; + self.push(v)?; + if protect { + protect_lua!(state, 3, 1, fn(state) ffi::lua_rawset(state, -3))?; + } else { + ffi::lua_rawset(state, -3); + } + } + + Ok(Table(self.pop_ref())) + } + + /// See [`Lua::create_sequence_from`] + pub(crate) unsafe fn create_sequence_from(&self, iter: I) -> Result
+ where + T: IntoLua, + I: IntoIterator, + { + let state = self.state(); + let _sg = StackGuard::new(state); + check_stack(state, 5)?; + + let iter = iter.into_iter(); + let lower_bound = iter.size_hint().0; + let protect = !self.unlikely_memory_error(); + push_table(state, lower_bound, 0, protect)?; + for (i, v) in iter.enumerate() { + self.push(v)?; + if protect { + protect_lua!(state, 2, 1, |state| { + ffi::lua_rawseti(state, -2, (i + 1) as Integer); + })?; + } else { + ffi::lua_rawseti(state, -2, (i + 1) as Integer); + } + } + + Ok(Table(self.pop_ref())) + } + + /// Wraps a Lua function into a new thread (or coroutine). + /// + /// Takes function by reference. + pub(crate) unsafe fn create_thread(&self, func: &Function) -> Result { + let state = self.state(); + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + let protect = !self.unlikely_memory_error(); + #[cfg(feature = "luau")] + let protect = protect || (*self.extra.get()).thread_creation_callback.is_some(); + + let thread_state = if !protect { + ffi::lua_newthread(state) + } else { + protect_lua!(state, 0, 1, |state| ffi::lua_newthread(state))? + }; + + // Inherit global hook if set + #[cfg(not(feature = "luau"))] + self.set_thread_hook(thread_state, HookKind::Global)?; + + let thread = Thread(self.pop_ref(), thread_state); + ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + Ok(thread) + } + + /// Wraps a Lua function into a new or recycled thread (coroutine). + #[cfg(feature = "async")] + pub(crate) unsafe fn create_recycled_thread(&self, func: &Function) -> Result { + if let Some(index) = (*self.extra.get()).thread_pool.pop() { + let thread_state = ffi::lua_tothread(self.ref_thread(), *index.0); + ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); + + #[cfg(feature = "luau")] + { + // Inherit `LUA_GLOBALSINDEX` from the caller + ffi::lua_xpush(self.state(), thread_state, ffi::LUA_GLOBALSINDEX); + ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); + } + + return Ok(Thread(ValueRef::new(self, index), thread_state)); + } + + self.create_thread(func) + } + + /// Returns the thread to the pool for later use. + #[cfg(feature = "async")] + pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { + let extra = &mut *self.extra.get(); + if extra.thread_pool.len() < extra.thread_pool.capacity() + && let Some(index) = thread.0.index_count.take() + { + extra.thread_pool.push(index); + } + } + + /// Pushes a primitive type value onto the Lua stack. + pub(crate) unsafe fn push_primitive_type(&self) -> bool { + match T::TYPE_ID { + ffi::LUA_TBOOLEAN => { + ffi::lua_pushboolean(self.state(), 0); + } + ffi::LUA_TLIGHTUSERDATA => { + ffi::lua_pushlightuserdata(self.state(), ptr::null_mut()); + } + ffi::LUA_TNUMBER => { + ffi::lua_pushnumber(self.state(), 0.); + } + #[cfg(feature = "luau")] + ffi::LUA_TVECTOR => { + #[cfg(not(feature = "luau-vector4"))] + ffi::lua_pushvector(self.state(), 0., 0., 0.); + #[cfg(feature = "luau-vector4")] + ffi::lua_pushvector(self.state(), 0., 0., 0., 0.); + } + ffi::LUA_TSTRING => { + ffi::lua_pushstring(self.state(), b"\0" as *const u8 as *const _); + } + ffi::LUA_TFUNCTION => { + unsafe extern "C-unwind" fn func(_state: *mut ffi::lua_State) -> c_int { + 0 + } + ffi::lua_pushcfunction(self.state(), func); + } + ffi::LUA_TTHREAD => { + ffi::lua_pushthread(self.state()); + } + #[cfg(feature = "luau")] + ffi::LUA_TBUFFER => { + ffi::lua_newbuffer(self.state(), 0); + } + _ => return false, + } + true + } + + /// Pushes a value that implements `IntoLua` onto the Lua stack. + /// + /// Uses up to 2 stack spaces to push a single value, does not call `checkstack`. + #[allow(clippy::missing_safety_doc)] + #[inline(always)] + pub unsafe fn push(&self, value: impl IntoLua) -> Result<()> { + value.push_into_stack(self) + } + + /// Pops a value that implements [`FromLua`] from the top of the Lua stack. + /// + /// Uses up to 1 stack space, does not call `checkstack`. + #[allow(clippy::missing_safety_doc)] + #[inline(always)] + pub unsafe fn pop(&self) -> Result { + let v = R::from_stack(-1, self)?; + ffi::lua_pop(self.state(), 1); + Ok(v) + } + + /// Pushes a `Value` (by reference) onto the Lua stack. + /// + /// Uses up to 2 stack spaces, does not call `checkstack`. + #[allow(clippy::missing_safety_doc)] + pub unsafe fn push_value(&self, value: &Value) -> Result<()> { + let state = self.state(); + match value { + Value::Nil => ffi::lua_pushnil(state), + Value::Boolean(b) => ffi::lua_pushboolean(state, *b as c_int), + Value::LightUserData(ud) => ffi::lua_pushlightuserdata(state, ud.0), + Value::Integer(i) => ffi::lua_pushinteger(state, *i), + Value::Number(n) => ffi::lua_pushnumber(state, *n), + #[cfg(feature = "luau")] + Value::Vector(v) => { + #[cfg(not(feature = "luau-vector4"))] + ffi::lua_pushvector(state, v.x(), v.y(), v.z()); + #[cfg(feature = "luau-vector4")] + ffi::lua_pushvector(state, v.x(), v.y(), v.z(), v.w()); + } + Value::String(s) => self.push_ref(&s.0), + Value::Table(t) => self.push_ref(&t.0), + Value::Function(f) => self.push_ref(&f.0), + Value::Thread(t) => self.push_ref(&t.0), + Value::UserData(ud) => self.push_ref(&ud.0), + #[cfg(feature = "luau")] + Value::Buffer(buf) => self.push_ref(&buf.0), + Value::Error(err) => { + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, WrappedFailure::Error(*err.clone()), protect)?; + } + Value::Other(vref) => self.push_ref(vref), + } + Ok(()) + } + + /// Pops a value from the Lua stack. + /// + /// Uses up to 1 stack spaces, does not call `checkstack`. + #[allow(clippy::missing_safety_doc)] + #[inline] + pub unsafe fn pop_value(&self) -> Value { + let value = self.stack_value(-1, None); + ffi::lua_pop(self.state(), 1); + value + } + + /// Returns value at given stack index without popping it. + /// + /// Uses up to 1 stack spaces, does not call `checkstack`. + pub(crate) unsafe fn stack_value(&self, idx: c_int, type_hint: Option) -> Value { + let state = self.state(); + match type_hint.unwrap_or_else(|| ffi::lua_type(state, idx)) { + ffi::LUA_TNIL => Nil, + + ffi::LUA_TBOOLEAN => Value::Boolean(ffi::lua_toboolean(state, idx) != 0), + + ffi::LUA_TLIGHTUSERDATA => Value::LightUserData(LightUserData(ffi::lua_touserdata(state, idx))), + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + ffi::LUA_TNUMBER => { + if ffi::lua_isinteger(state, idx) != 0 { + Value::Integer(ffi::lua_tointeger(state, idx)) + } else { + Value::Number(ffi::lua_tonumber(state, idx)) + } + } + + #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::LUA_TNUMBER => { + let n = ffi::lua_tonumber(state, idx); + match num_traits::cast(n) { + Some(i) if n.to_bits() == (i as crate::types::Number).to_bits() => Value::Integer(i), + _ => Value::Number(n), + } + } + + #[cfg(feature = "luau")] + ffi::LUA_TINTEGER => { + let i = ffi::lua_tointeger64(state, idx, ptr::null_mut()); + match num_traits::cast(i) { + Some(i) => Value::Integer(i), + _ => Value::Number(i as crate::types::Number), + } + } + + #[cfg(feature = "luau")] + ffi::LUA_TVECTOR => { + let v = ffi::lua_tovector(state, idx); + mlua_debug_assert!(!v.is_null(), "vector is null"); + #[cfg(not(feature = "luau-vector4"))] + return Value::Vector(crate::Vector([*v, *v.add(1), *v.add(2)])); + #[cfg(feature = "luau-vector4")] + return Value::Vector(crate::Vector([*v, *v.add(1), *v.add(2), *v.add(3)])); + } + + ffi::LUA_TSTRING => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::String(LuaString(self.pop_ref_thread())) + } + + ffi::LUA_TTABLE => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::Table(Table(self.pop_ref_thread())) + } + + ffi::LUA_TFUNCTION => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::Function(Function(self.pop_ref_thread())) + } + + ffi::LUA_TUSERDATA => { + // If the userdata is `WrappedFailure`, process it as an error or panic. + let failure_mt_ptr = (*self.extra.get()).wrapped_failure_mt_ptr; + match get_internal_userdata::(state, idx, failure_mt_ptr).as_mut() { + Some(WrappedFailure::Error(err)) => Value::Error(Box::new(err.clone())), + Some(WrappedFailure::Panic(panic)) => { + if let Some(panic) = panic.take() { + resume_unwind(panic); + } + // Previously resumed panic? + Value::Nil + } + _ => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::UserData(AnyUserData(self.pop_ref_thread())) + } + } + } + + ffi::LUA_TTHREAD => { + ffi::lua_xpush(state, self.ref_thread(), idx); + let thread_state = ffi::lua_tothread(self.ref_thread(), -1); + Value::Thread(Thread(self.pop_ref_thread(), thread_state)) + } + + #[cfg(feature = "luau")] + ffi::LUA_TBUFFER => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::Buffer(crate::Buffer(self.pop_ref_thread())) + } + + _ => { + ffi::lua_xpush(state, self.ref_thread(), idx); + Value::Other(self.pop_ref_thread()) + } + } + } + + // Pushes a ValueRef value onto the stack, uses 1 stack space, does not call checkstack + #[inline] + pub(crate) fn push_ref(&self, vref: &ValueRef) { + assert!( + self.weak() == &vref.lua, + "Lua instance passed Value created from a different main Lua state" + ); + unsafe { ffi::lua_xpush(self.ref_thread(), self.state(), vref.index) }; + } + + // Pops the topmost element of the stack and stores a reference to it. This pins the object, + // preventing garbage collection until the returned `ValueRef` is dropped. + // + // References are stored on the stack of a specially created auxiliary thread that exists only + // to store reference values. This is much faster than storing these in the registry, and also + // much more flexible and requires less bookkeeping than storing them directly in the currently + // used stack. + #[inline] + pub(crate) unsafe fn pop_ref(&self) -> ValueRef { + ffi::lua_xmove(self.state(), self.ref_thread(), 1); + let index = (*self.extra.get()).ref_stack_pop(); + ValueRef::new(self, index) + } + + // Same as `pop_ref` but assumes the value is already on the reference thread + #[inline] + pub(crate) unsafe fn pop_ref_thread(&self) -> ValueRef { + let index = (*self.extra.get()).ref_stack_pop(); + ValueRef::new(self, index) + } + + pub(crate) unsafe fn drop_ref(&self, vref: &ValueRef) { + let ref_thread = self.ref_thread(); + mlua_debug_assert!( + ffi::lua_gettop(ref_thread) >= vref.index, + "GC finalizer is not allowed in ref_thread" + ); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, vref.index); + (*self.extra.get()).ref_free.push(vref.index); + } + + #[inline] + pub(crate) unsafe fn push_error_traceback(&self) { + let state = self.state(); + #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))] + ffi::lua_xpush(self.ref_thread(), state, ExtraData::ERROR_TRACEBACK_IDX); + // Lua 5.2+ support light C functions that does not require extra allocations + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + ffi::lua_pushcfunction(state, crate::util::error_traceback); + } + + #[inline] + pub(crate) unsafe fn unlikely_memory_error(&self) -> bool { + #[cfg(debug_assertions)] + if cfg!(force_memory_limit) { + return false; + } + + // MemoryInfo is empty in module mode so we cannot predict memory limits + match MemoryState::get(self.state()) { + mem_state if !mem_state.is_null() => (*mem_state).memory_limit() == 0, + _ => (*self.extra.get()).skip_memory_check, // Check the special flag (only for module mode) + } + } + + pub(crate) unsafe fn make_userdata(&self, data: UserDataStorage) -> Result + where + T: UserData + 'static, + { + self.make_userdata_with_metatable(data, || { + // Check if userdata/metatable is already registered + let type_id = TypeId::of::(); + if let Some(&table_id) = (*self.extra.get()).registered_userdata_t.get(&type_id) { + return Ok(table_id); + } + + // Create a new metatable from `UserData` definition + let mut registry = UserDataRegistry::new(self.lua()); + T::register(&mut registry); + + self.create_userdata_metatable(registry.into_raw()) + }) + } + + pub(crate) unsafe fn make_any_userdata(&self, data: UserDataStorage) -> Result + where + T: 'static, + { + self.make_userdata_with_metatable(data, || { + // Check if userdata/metatable is already registered + let type_id = TypeId::of::(); + if let Some(&table_id) = (*self.extra.get()).registered_userdata_t.get(&type_id) { + return Ok(table_id); + } + + // Check if metatable creation is pending or create an empty metatable otherwise + let registry = match (*self.extra.get()).pending_userdata_reg.remove(&type_id) { + Some(registry) => registry, + None => UserDataRegistry::::new(self.lua()).into_raw(), + }; + self.create_userdata_metatable(registry) + }) + } + + unsafe fn make_userdata_with_metatable( + &self, + data: UserDataStorage, + get_metatable_id: impl FnOnce() -> Result, + ) -> Result { + let state = self.state(); + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + // We generate metatable first to make sure it *always* available when userdata pushed + let mt_id = get_metatable_id()?; + let protect = !self.unlikely_memory_error(); + push_userdata(state, data, protect)?; + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, mt_id as _); + ffi::lua_setmetatable(state, -2); + + // Set empty environment for Lua 5.1 + #[cfg(any(feature = "lua51", feature = "luajit"))] + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_newtable(state); + ffi::lua_setuservalue(state, -2); + })?; + } else { + ffi::lua_newtable(state); + ffi::lua_setuservalue(state, -2); + } + + Ok(AnyUserData(self.pop_ref())) + } + + pub(crate) unsafe fn create_userdata_metatable(&self, registry: RawUserDataRegistry) -> Result { + let state = self.state(); + let type_id = registry.type_id; + + self.push_userdata_metatable(registry)?; + + let mt_ptr = ffi::lua_topointer(state, -1); + let id = protect_lua!(state, 1, 0, |state| { + ffi::luaL_ref(state, ffi::LUA_REGISTRYINDEX) + })?; + + if let Some(type_id) = type_id { + (*self.extra.get()).registered_userdata_t.insert(type_id, id); + } + self.register_userdata_metatable(mt_ptr, type_id); + + Ok(id) + } + + pub(crate) unsafe fn push_userdata_metatable(&self, mut registry: RawUserDataRegistry) -> Result<()> { + let state = self.state(); + let mut stack_guard = StackGuard::new(state); + check_stack(state, 13)?; + + // Prepare metatable, add meta methods first and then meta fields + let metatable_nrec = registry.meta_methods.len() + registry.meta_fields.len(); + #[cfg(feature = "async")] + let metatable_nrec = metatable_nrec + registry.async_meta_methods.len(); + push_table(state, 0, metatable_nrec, true)?; + for (k, m) in registry.meta_methods { + self.push(self.create_callback(m)?)?; + rawset_field(state, -2, MetaMethod::validate(&k)?)?; + } + #[cfg(feature = "async")] + for (k, m) in registry.async_meta_methods { + self.push(self.create_async_callback(m)?)?; + rawset_field(state, -2, MetaMethod::validate(&k)?)?; + } + let mut has_name = false; + for (k, v) in registry.meta_fields { + has_name = has_name || k == MetaMethod::Type; + v?.push_into_stack(self)?; + rawset_field(state, -2, MetaMethod::validate(&k)?)?; + } + // Set `__name/__type` if not provided + if !has_name { + let type_name = registry.type_name; + push_string(state, type_name.as_bytes(), !self.unlikely_memory_error())?; + rawset_field(state, -2, MetaMethod::Type.name())?; + } + let metatable_index = ffi::lua_absindex(state, -1); + + let fields_nrec = registry.fields.len(); + if fields_nrec > 0 { + // If `__index` is a table then update it in-place + let index_type = ffi::lua_getfield(state, metatable_index, cstr!("__index")); + match index_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE => { + if index_type == ffi::LUA_TNIL { + // Create a new table + ffi::lua_pop(state, 1); + push_table(state, 0, fields_nrec, true)?; + } + for (k, v) in mem::take(&mut registry.fields) { + v?.push_into_stack(self)?; + rawset_field(state, -2, &k)?; + } + rawset_field(state, metatable_index, "__index")?; + } + _ => { + ffi::lua_pop(state, 1); + // Fields will be converted to functions and added to field getters + } + } + } + + let mut field_getters_index = None; + let field_getters_nrec = registry.field_getters.len() + registry.fields.len(); + if field_getters_nrec > 0 { + push_table(state, 0, field_getters_nrec, true)?; + for (k, m) in registry.field_getters { + self.push(self.create_callback(m)?)?; + rawset_field(state, -2, &k)?; + } + for (k, v) in registry.fields { + unsafe extern "C-unwind" fn return_field(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + 1 + } + v?.push_into_stack(self)?; + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, return_field, 1); + })?; + rawset_field(state, -2, &k)?; + } + field_getters_index = Some(ffi::lua_absindex(state, -1)); + } + + let mut field_setters_index = None; + let field_setters_nrec = registry.field_setters.len(); + if field_setters_nrec > 0 { + push_table(state, 0, field_setters_nrec, true)?; + for (k, m) in registry.field_setters { + self.push(self.create_callback(m)?)?; + rawset_field(state, -2, &k)?; + } + field_setters_index = Some(ffi::lua_absindex(state, -1)); + } + + // Create methods namecall table + #[cfg_attr(not(feature = "luau"), allow(unused_mut))] + let mut methods_map = None; + #[cfg(feature = "luau")] + if registry.enable_namecall { + let map: &mut rustc_hash::FxHashMap<_, crate::types::CallbackPtr> = + methods_map.get_or_insert_default(); + for (k, m) in ®istry.methods { + map.insert(k.as_bytes().to_vec(), &**m); + } + } + + let mut methods_index = None; + let methods_nrec = registry.methods.len(); + #[cfg(feature = "async")] + let methods_nrec = methods_nrec + registry.async_methods.len(); + if methods_nrec > 0 { + // If `__index` is a table then update it in-place + let index_type = ffi::lua_getfield(state, metatable_index, cstr!("__index")); + match index_type { + ffi::LUA_TTABLE => {} // Update the existing table + _ => { + // Create a new table + ffi::lua_pop(state, 1); + push_table(state, 0, methods_nrec, true)?; + } + } + for (k, m) in registry.methods { + self.push(self.create_callback(m)?)?; + rawset_field(state, -2, &k)?; + } + #[cfg(feature = "async")] + for (k, m) in registry.async_methods { + self.push(self.create_async_callback(m)?)?; + rawset_field(state, -2, &k)?; + } + match index_type { + ffi::LUA_TTABLE => { + ffi::lua_pop(state, 1); // All done + } + ffi::LUA_TNIL => { + // Set the new table as `__index` + rawset_field(state, metatable_index, "__index")?; + } + _ => { + methods_index = Some(ffi::lua_absindex(state, -1)); + } + } + } + + ffi::lua_pushcfunction(state, registry.destructor); + rawset_field(state, metatable_index, "__gc")?; + + init_userdata_metatable( + state, + metatable_index, + field_getters_index, + field_setters_index, + methods_index, + methods_map, + )?; + + // Update stack guard to keep metatable after return + stack_guard.keep(1); + + Ok(()) + } + + #[inline(always)] + pub(crate) unsafe fn register_userdata_metatable(&self, mt_ptr: *const c_void, type_id: Option) { + (*self.extra.get()).registered_userdata_mt.insert(mt_ptr, type_id); + } + + #[inline(always)] + pub(crate) unsafe fn deregister_userdata_metatable(&self, mt_ptr: *const c_void) { + (*self.extra.get()).registered_userdata_mt.remove(&mt_ptr); + if (*self.extra.get()).last_checked_userdata_mt.0 == mt_ptr { + (*self.extra.get()).last_checked_userdata_mt = (ptr::null(), None); + } + } + + // Returns `TypeId` for the userdata ref, checking that it's registered and not destructed. + // + // Returns `None` if the userdata is registered but non-static. + #[inline(always)] + pub(crate) fn get_userdata_ref_type_id(&self, vref: &ValueRef) -> Result> { + unsafe { self.get_userdata_type_id_inner(self.ref_thread(), vref.index) } + } + + // Same as `get_userdata_ref_type_id` but assumes the userdata is already on the stack. + pub(crate) unsafe fn get_userdata_type_id( + &self, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result> { + match self.get_userdata_type_id_inner(state, idx) { + Ok(type_id) => Ok(type_id), + Err(Error::UserDataTypeMismatch) if ffi::lua_type(state, idx) != ffi::LUA_TUSERDATA => { + // Report `FromLuaConversionError` instead + let type_name = CStr::from_ptr(ffi::lua_typename(state, ffi::lua_type(state, idx))) + .to_str() + .unwrap_or("unknown"); + let message = format!("expected userdata of type '{}'", short_type_name::()); + Err(Error::from_lua_conversion(type_name, "userdata", message)) + } + Err(err) => Err(err), + } + } + + unsafe fn get_userdata_type_id_inner( + &self, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result> { + let mt_ptr = get_metatable_ptr(state, idx); + if mt_ptr.is_null() { + return Err(Error::UserDataTypeMismatch); + } + + // Fast path to skip looking up the metatable in the map + let (last_mt, last_type_id) = (*self.extra.get()).last_checked_userdata_mt; + if last_mt == mt_ptr { + return Ok(last_type_id); + } + + match (*self.extra.get()).registered_userdata_mt.get(&mt_ptr) { + Some(&type_id) if type_id == Some(TypeId::of::()) => { + Err(Error::UserDataDestructed) + } + Some(&type_id) => { + (*self.extra.get()).last_checked_userdata_mt = (mt_ptr, type_id); + Ok(type_id) + } + None => Err(Error::UserDataTypeMismatch), + } + } + + // Pushes a ValueRef (userdata) value onto the stack, returning their `TypeId`. + // Uses 1 stack space, does not call checkstack. + pub(crate) unsafe fn push_userdata_ref(&self, vref: &ValueRef) -> Result> { + let type_id = self.get_userdata_type_id_inner(self.ref_thread(), vref.index)?; + self.push_ref(vref); + Ok(type_id) + } + + // Creates a Function out of a Callback containing a 'static Fn. + pub(crate) fn create_callback(&self, func: Callback) -> Result { + unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int { + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + match (*upvalue).data { + Some(ref func) => func(rawlua, nargs), + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let func = Some(func); + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + push_internal_userdata(state, CallbackUpvalue { data: func, extra }, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, call_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, call_callback, 1); + } + + Ok(Function(self.pop_ref())) + } + } + + #[cfg(feature = "async")] + pub(crate) fn create_async_callback(&self, func: AsyncCallback) -> Result { + // Ensure that the coroutine library is loaded + #[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luau" + ))] + unsafe { + if !(*self.extra.get()).libs.contains(StdLib::COROUTINE) { + load_std_libs(self.main_state(), StdLib::COROUTINE)?; + (*self.extra.get()).libs |= StdLib::COROUTINE; + } + } + + unsafe extern "C-unwind" fn get_future_callback(state: *mut ffi::lua_State) -> c_int { + // Async functions cannot be scoped and therefore destroyed, + // so the first upvalue is always valid + let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); + callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the callback is executed + let rawlua = (*extra).raw_lua(); + + let func = &*(*upvalue).data; + let fut = Some(func(rawlua, nargs)); + let extra = XRc::clone(&(*upvalue).extra); + let protect = !rawlua.unlikely_memory_error(); + push_internal_userdata(state, AsyncPollUpvalue { data: fut, extra }, protect)?; + + Ok(1) + }) + } + + unsafe extern "C-unwind" fn poll_future(state: *mut ffi::lua_State) -> c_int { + // Future is always passed in the first argument + let future = get_userdata::(state, 1); + callback_error_ext(state, (*future).extra.get(), true, |extra, nargs| { + // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) + // The lock must be already held as the future is polled + let rawlua = (*extra).raw_lua(); + + if nargs == 2 && ffi::lua_tolightuserdata(state, -1) == Lua::poll_terminate().0 { + // Destroy the future and terminate the Lua thread + (*future).data.take(); + ffi::lua_pushinteger(state, -1); + return Ok(1); + } + + let fut = &mut (*future).data; + let mut ctx = Context::from_waker(rawlua.waker()); + match fut.as_mut().map(|fut| fut.as_mut().poll(&mut ctx)) { + Some(Poll::Pending) => { + let fut_nvals = ffi::lua_gettop(state) - 1; // Exclude the future itself + if fut_nvals >= 3 && ffi::lua_tolightuserdata(state, -3) == Lua::poll_yield().0 { + // We have some values to yield + ffi::lua_pushnil(state); + ffi::lua_replace(state, -4); + return Ok(3); + } + ffi::lua_pushnil(state); + ffi::lua_pushlightuserdata(state, Lua::poll_pending().0); + Ok(2) + } + Some(Poll::Ready(nresults)) => { + match nresults? { + nresults if nresults < 3 => { + // Fast path for up to 2 results without creating a table + ffi::lua_pushinteger(state, nresults as _); + if nresults > 0 { + ffi::lua_insert(state, -nresults - 1); + } + Ok(nresults + 1) + } + nresults => { + let results = MultiValue::from_stack_multi(nresults, rawlua)?; + ffi::lua_pushinteger(state, nresults as _); + rawlua.push(rawlua.create_sequence_from(results)?)?; + Ok(2) + } + } + } + None => Err(Error::CallbackDestructed), + } + }) + } + + let state = self.state(); + let get_future = unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + let extra = XRc::clone(&self.extra); + let protect = !self.unlikely_memory_error(); + let upvalue = AsyncCallbackUpvalue { data: func, extra }; + push_internal_userdata(state, upvalue, protect)?; + if protect { + protect_lua!(state, 1, 1, fn(state) { + ffi::lua_pushcclosure(state, get_future_callback, 1); + })?; + } else { + ffi::lua_pushcclosure(state, get_future_callback, 1); + } + + Function(self.pop_ref()) + }; + + unsafe extern "C-unwind" fn unpack(state: *mut ffi::lua_State) -> c_int { + let len = ffi::lua_tointeger(state, 2); + ffi::luaL_checkstack(state, len as c_int, ptr::null()); + for i in 1..=len { + ffi::lua_rawgeti(state, 1, i); + } + len as c_int + } + + let lua = self.lua(); + let coroutine = lua.globals().get::
("coroutine")?; + + // Prepare environment for the async poller + let env = lua.create_table_with_capacity(0, 4)?; + env.set("get_future", get_future)?; + env.set("poll", unsafe { lua.create_c_function(poll_future)? })?; + env.set("yield", coroutine.get::("yield")?)?; + env.set("unpack", unsafe { lua.create_c_function(unpack)? })?; + + lua.load( + r#" + local poll, yield = poll, yield + local future = get_future(...) + local nres, res, res2 = poll(future) + while true do + -- Poll::Ready branch, `nres` is the number of results + if nres ~= nil then + if nres == 0 then + return + elseif nres == 1 then + return res + elseif nres == 2 then + return res, res2 + elseif nres < 0 then + -- Negative `nres` means that the future is terminated + -- It must stay yielded and never be resumed again + yield() + else + return unpack(res, nres) + end + end + + -- Poll::Pending branch + if res2 == nil then + -- `res` is a "pending" value + -- `yield` can return a signal to drop the future that we should propagate + -- to the poller + nres, res, res2 = poll(future, yield(res)) + elseif res2 == 0 then + nres, res, res2 = poll(future, yield()) + elseif res2 == 1 then + nres, res, res2 = poll(future, yield(res)) + else + nres, res, res2 = poll(future, yield(unpack(res, res2))) + end + end + "#, + ) + .try_cache() + .set_name("=__mlua_async_poll") + .set_environment(env) + .into_function() + } + + #[cfg(feature = "async")] + #[inline] + pub(crate) fn waker(&self) -> &Waker { + unsafe { (*self.extra.get()).waker.as_ref() } + } + + #[cfg(feature = "async")] + #[inline] + pub(crate) fn set_waker(&self, waker: NonNull) -> NonNull { + unsafe { mem::replace(&mut (*self.extra.get()).waker, waker) } + } +} + +// Uses 3 stack spaces +unsafe fn load_std_libs(state: *mut ffi::lua_State, libs: StdLib) -> Result<()> { + unsafe fn requiref( + state: *mut ffi::lua_State, + modname: *const c_char, + openf: ffi::lua_CFunction, + glb: c_int, + ) -> Result<()> { + protect_lua!(state, 0, 0, |state| { + ffi::luaL_requiref(state, modname, openf, glb) + }) + } + + #[cfg(feature = "luajit")] + struct GcGuard(*mut ffi::lua_State); + + #[cfg(feature = "luajit")] + impl GcGuard { + fn new(state: *mut ffi::lua_State) -> Self { + // Stop collector during library initialization + unsafe { ffi::lua_gc(state, ffi::LUA_GCSTOP, 0) }; + GcGuard(state) + } + } + + #[cfg(feature = "luajit")] + impl Drop for GcGuard { + fn drop(&mut self) { + unsafe { ffi::lua_gc(self.0, ffi::LUA_GCRESTART, -1) }; + } + } + + // Stop collector during library initialization + #[cfg(feature = "luajit")] + let _gc_guard = GcGuard::new(state); + + #[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luau" + ))] + { + if libs.contains(StdLib::COROUTINE) { + requiref(state, ffi::LUA_COLIBNAME, ffi::luaopen_coroutine, 1)?; + } + } + + if libs.contains(StdLib::TABLE) { + requiref(state, ffi::LUA_TABLIBNAME, ffi::luaopen_table, 1)?; + } + + #[cfg(not(feature = "luau"))] + if libs.contains(StdLib::IO) { + requiref(state, ffi::LUA_IOLIBNAME, ffi::luaopen_io, 1)?; + } + + if libs.contains(StdLib::OS) { + requiref(state, ffi::LUA_OSLIBNAME, ffi::luaopen_os, 1)?; + } + + if libs.contains(StdLib::STRING) { + requiref(state, ffi::LUA_STRLIBNAME, ffi::luaopen_string, 1)?; + } + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] + { + if libs.contains(StdLib::UTF8) { + requiref(state, ffi::LUA_UTF8LIBNAME, ffi::luaopen_utf8, 1)?; + } + } + + #[cfg(any(feature = "lua52", feature = "luau"))] + { + if libs.contains(StdLib::BIT) { + requiref(state, ffi::LUA_BITLIBNAME, ffi::luaopen_bit32, 1)?; + } + } + + #[cfg(feature = "luajit")] + { + if libs.contains(StdLib::BIT) { + requiref(state, ffi::LUA_BITLIBNAME, ffi::luaopen_bit, 1)?; + } + } + + #[cfg(feature = "luau")] + if libs.contains(StdLib::BUFFER) { + requiref(state, ffi::LUA_BUFFERLIBNAME, ffi::luaopen_buffer, 1)?; + } + + #[cfg(feature = "luau")] + if libs.contains(StdLib::VECTOR) { + requiref(state, ffi::LUA_VECLIBNAME, ffi::luaopen_vector, 1)?; + } + + #[cfg(feature = "luau")] + if libs.contains(StdLib::INTEGER) { + requiref(state, ffi::LUA_INTLIBNAME, ffi::luaopen_integer, 1)?; + } + + if libs.contains(StdLib::MATH) { + requiref(state, ffi::LUA_MATHLIBNAME, ffi::luaopen_math, 1)?; + } + + if libs.contains(StdLib::DEBUG) { + requiref(state, ffi::LUA_DBLIBNAME, ffi::luaopen_debug, 1)?; + } + + #[cfg(not(feature = "luau"))] + if libs.contains(StdLib::PACKAGE) { + requiref(state, ffi::LUA_LOADLIBNAME, ffi::luaopen_package, 1)?; + } + + #[cfg(feature = "luajit")] + if libs.contains(StdLib::JIT) { + requiref(state, ffi::LUA_JITLIBNAME, ffi::luaopen_jit, 1)?; + } + + #[cfg(feature = "luajit")] + if libs.contains(StdLib::FFI) { + requiref(state, ffi::LUA_FFILIBNAME, ffi::luaopen_ffi, 1)?; + } + + Ok(()) +} diff --git a/src/state/util.rs b/src/state/util.rs new file mode 100644 index 00000000..fdbc5cf4 --- /dev/null +++ b/src/state/util.rs @@ -0,0 +1,152 @@ +use std::os::raw::c_int; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::ptr; +use std::sync::Arc; + +use crate::error::{Error, Result}; +use crate::state::{ExtraData, RawLua}; +use crate::util::{self, WrappedFailure, get_internal_metatable}; + +struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); + +impl<'a> StateGuard<'a> { + fn new(inner: &'a RawLua, mut state: *mut ffi::lua_State) -> Self { + state = inner.state.replace(state); + Self(inner, state) + } +} + +impl Drop for StateGuard<'_> { + fn drop(&mut self) { + self.0.state.set(self.1); + } +} + +// An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata +// and instead reuses unused values from previous calls (or allocates new). +pub(crate) unsafe fn callback_error_ext( + state: *mut ffi::lua_State, + mut extra: *mut ExtraData, + wrap_error: bool, + f: F, +) -> R +where + F: FnOnce(*mut ExtraData, c_int) -> Result, +{ + if extra.is_null() { + extra = ExtraData::get(state); + } + + let nargs = ffi::lua_gettop(state); + + enum PreallocatedFailure { + New(*mut WrappedFailure), + Reserved, + } + + impl PreallocatedFailure { + unsafe fn reserve(state: *mut ffi::lua_State, extra: *mut ExtraData) -> Self { + if (*extra).wrapped_failure_top > 0 { + (*extra).wrapped_failure_top -= 1; + return PreallocatedFailure::Reserved; + } + + // We need to check stack for Luau in case when callback is called from interrupt + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + // Place it to the beginning of the stack + let ud = WrappedFailure::new_userdata(state); + ffi::lua_insert(state, 1); + PreallocatedFailure::New(ud) + } + + #[cold] + unsafe fn r#use(&self, state: *mut ffi::lua_State, extra: *mut ExtraData) -> *mut WrappedFailure { + let ref_thread = (*extra).ref_thread; + match *self { + PreallocatedFailure::New(ud) => { + ffi::lua_settop(state, 1); + ud + } + PreallocatedFailure::Reserved => { + let index = (*extra).wrapped_failure_pool.pop().unwrap(); + ffi::lua_settop(state, 0); + #[cfg(feature = "luau")] + ffi::lua_rawcheckstack(state, 2); + ffi::lua_xpush(ref_thread, state, index); + ffi::lua_pushnil(ref_thread); + ffi::lua_replace(ref_thread, index); + (*extra).ref_free.push(index); + ffi::lua_touserdata(state, -1) as *mut WrappedFailure + } + } + } + + unsafe fn release(self, state: *mut ffi::lua_State, extra: *mut ExtraData) { + let ref_thread = (*extra).ref_thread; + match self { + PreallocatedFailure::New(_) => { + ffi::lua_rotate(state, 1, -1); + ffi::lua_xmove(state, ref_thread, 1); + let index = (*extra).ref_stack_pop(); + (*extra).wrapped_failure_pool.push(index); + (*extra).wrapped_failure_top += 1; + } + PreallocatedFailure::Reserved => (*extra).wrapped_failure_top += 1, + } + } + } + + // We cannot shadow Rust errors with Lua ones, so we need to reserve pre-allocated memory + // to store a wrapped failure (error or panic) *before* we proceed. + let prealloc_failure = PreallocatedFailure::reserve(state, extra); + + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { + Ok(Ok(r)) => { + // Return unused `WrappedFailure` to the pool + prealloc_failure.release(state, extra); + r + } + Ok(Err(err)) => { + let wrapped_error = prealloc_failure.r#use(state, extra); + + if !wrap_error { + ptr::write(wrapped_error, WrappedFailure::Error(err)); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = util::to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + ptr::write( + wrapped_error, + WrappedFailure::Error(Error::CallbackError { traceback, cause }), + ); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + + ffi::lua_error(state) + } + Err(p) => { + let wrapped_panic = prealloc_failure.r#use(state, extra); + ptr::write(wrapped_panic, WrappedFailure::Panic(Some(p))); + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + ffi::lua_error(state) + } + } +} diff --git a/src/stdlib.rs b/src/stdlib.rs index 34c7629c..e1c32e9a 100644 --- a/src/stdlib.rs +++ b/src/stdlib.rs @@ -1,5 +1,4 @@ -use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign}; -use std::u32; +use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not}; /// Flags describing the set of lua standard libraries to load. #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] @@ -7,54 +6,88 @@ pub struct StdLib(u32); impl StdLib { /// [`coroutine`](https://www.lua.org/manual/5.4/manual.html#6.2) library - /// - /// Requires `feature = "lua54/lua53/lua52/luau"` #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "luau" ))] + #[cfg_attr( + docsrs, + doc(cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luau" + ))) + )] pub const COROUTINE: StdLib = StdLib(1); + /// [`table`](https://www.lua.org/manual/5.4/manual.html#6.6) library pub const TABLE: StdLib = StdLib(1 << 1); + /// [`io`](https://www.lua.org/manual/5.4/manual.html#6.8) library #[cfg(not(feature = "luau"))] #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub const IO: StdLib = StdLib(1 << 2); + /// [`os`](https://www.lua.org/manual/5.4/manual.html#6.9) library pub const OS: StdLib = StdLib(1 << 3); + /// [`string`](https://www.lua.org/manual/5.4/manual.html#6.4) library pub const STRING: StdLib = StdLib(1 << 4); + /// [`utf8`](https://www.lua.org/manual/5.4/manual.html#6.5) library - /// - /// Requires `feature = "lua54/lua53/luau"` - #[cfg(any(feature = "lua54", feature = "lua53", feature = "luau"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))) + )] pub const UTF8: StdLib = StdLib(1 << 5); + /// [`bit`](https://www.lua.org/manual/5.2/manual.html#6.7) library - /// - /// Requires `feature = "lua52/luajit/luau"` #[cfg(any(feature = "lua52", feature = "luajit", feature = "luau", doc))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua52", feature = "luajit", feature = "luau"))) + )] pub const BIT: StdLib = StdLib(1 << 6); + /// [`math`](https://www.lua.org/manual/5.4/manual.html#6.7) library pub const MATH: StdLib = StdLib(1 << 7); + /// [`package`](https://www.lua.org/manual/5.4/manual.html#6.3) library #[cfg(not(feature = "luau"))] #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub const PACKAGE: StdLib = StdLib(1 << 8); + + /// [`buffer`](https://luau.org/library#buffer-library) library + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub const BUFFER: StdLib = StdLib(1 << 9); + + /// [`vector`](https://luau.org/library#vector-library) library + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub const VECTOR: StdLib = StdLib(1 << 10); + + /// [`integer`](https://luau.org/library#integer-library) library + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub const INTEGER: StdLib = StdLib(1 << 11); + /// [`jit`](http://luajit.org/ext_jit.html) library - /// - /// Requires `feature = "luajit"` #[cfg(any(feature = "luajit", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luajit")))] - pub const JIT: StdLib = StdLib(1 << 9); + pub const JIT: StdLib = StdLib(1 << 12); /// (**unsafe**) [`ffi`](http://luajit.org/ext_ffi.html) library - /// - /// Requires `feature = "luajit"` #[cfg(any(feature = "luajit", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luajit")))] pub const FFI: StdLib = StdLib(1 << 30); + /// (**unsafe**) [`debug`](https://www.lua.org/manual/5.4/manual.html#6.10) library pub const DEBUG: StdLib = StdLib(1 << 31); @@ -111,3 +144,10 @@ impl BitXorAssign for StdLib { *self = StdLib(self.0 ^ rhs.0) } } + +impl Not for StdLib { + type Output = Self; + fn not(self) -> Self::Output { + StdLib(!self.0) + } +} diff --git a/src/string.rs b/src/string.rs index 87234f09..63ceaf3f 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,57 +1,67 @@ -use std::borrow::{Borrow, Cow}; +//! Lua string handling. +//! +//! This module provides types for working with Lua strings from Rust. + +use std::borrow::Borrow; use std::hash::{Hash, Hasher}; -use std::string::String as StdString; -use std::{slice, str}; +use std::ops::Deref; +use std::os::raw::{c_int, c_void}; +use std::{cmp, fmt, mem, slice, str}; + +use crate::error::{Error, Result}; +use crate::state::Lua; +use crate::traits::IntoLua; +use crate::types::{LuaType, ValueRef}; +use crate::value::Value; -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] use { serde::ser::{Serialize, Serializer}, std::result::Result as StdResult, }; -use crate::error::{Error, Result}; -use crate::ffi; -use crate::types::LuaRef; -use crate::util::{assert_stack, StackGuard}; - /// Handle to an internal Lua string. /// /// Unlike Rust strings, Lua strings may not be valid UTF-8. -#[derive(Clone, Debug)] -pub struct String<'lua>(pub(crate) LuaRef<'lua>); +#[derive(Clone, PartialEq)] +pub struct LuaString(pub(crate) ValueRef); -impl<'lua> String<'lua> { - /// Get a `&str` slice if the Lua string is valid UTF-8. +impl LuaString { + /// Get a [`BorrowedStr`] if the Lua string is valid UTF-8. + /// + /// The returned `BorrowedStr` holds a strong reference to the Lua state to guarantee the + /// validity of the underlying data. /// /// # Examples /// /// ``` - /// # use mlua::{Lua, Result, String}; + /// # use mlua::{Lua, LuaString, Result}; /// # fn main() -> Result<()> { /// # let lua = Lua::new(); /// let globals = lua.globals(); /// - /// let version: String = globals.get("_VERSION")?; + /// let version: LuaString = globals.get("_VERSION")?; /// assert!(version.to_str()?.contains("Lua")); /// - /// let non_utf8: String = lua.load(r#" "test\255" "#).eval()?; + /// let non_utf8: LuaString = lua.load(r#" "test\255" "#).eval()?; /// assert!(non_utf8.to_str().is_err()); /// # Ok(()) /// # } /// ``` - pub fn to_str(&self) -> Result<&str> { - str::from_utf8(self.as_bytes()).map_err(|e| Error::FromLuaConversionError { - from: "string", - to: "&str", - message: Some(e.to_string()), - }) + #[inline] + pub fn to_str(&self) -> Result { + BorrowedStr::try_from(self) } - /// Converts this string to a [`Cow`]. + /// Converts this Lua string to a [`String`]. /// /// Any non-Unicode sequences are replaced with [`U+FFFD REPLACEMENT CHARACTER`][U+FFFD]. /// + /// This method returns [`String`] instead of [`Cow<'_, str>`] because lifetime cannot be + /// bound to a weak Lua object. + /// /// [U+FFFD]: std::char::REPLACEMENT_CHARACTER + /// [`Cow<'_, str>`]: std::borrow::Cow /// /// # Examples /// @@ -65,101 +75,363 @@ impl<'lua> String<'lua> { /// # Ok(()) /// # } /// ``` - pub fn to_string_lossy(&self) -> Cow<'_, str> { - StdString::from_utf8_lossy(self.as_bytes()) + #[inline] + pub fn to_string_lossy(&self) -> String { + String::from_utf8_lossy(&self.as_bytes()).into_owned() + } + + /// Returns an object that implements [`Display`] for safely printing a [`LuaString`] that may + /// contain non-Unicode data. + /// + /// This may perform lossy conversion. + /// + /// [`Display`]: fmt::Display + pub fn display(&self) -> impl fmt::Display + '_ { + Display(self) } /// Get the bytes that make up this string. /// - /// The returned slice will not contain the terminating nul byte, but will contain any nul - /// bytes embedded into the Lua string. + /// The returned `BorrowedStr` holds a strong reference to the Lua state to guarantee the + /// validity of the underlying data. The data will not contain the terminating null byte, but + /// will contain any null bytes embedded into the Lua string. /// /// # Examples /// /// ``` - /// # use mlua::{Lua, Result, String}; + /// # use mlua::{Lua, LuaString, Result}; /// # fn main() -> Result<()> { /// # let lua = Lua::new(); - /// let non_utf8: String = lua.load(r#" "test\255" "#).eval()?; + /// let non_utf8: LuaString = lua.load(r#" "test\255" "#).eval()?; /// assert!(non_utf8.to_str().is_err()); // oh no :( /// assert_eq!(non_utf8.as_bytes(), &b"test\xff"[..]); /// # Ok(()) /// # } /// ``` - pub fn as_bytes(&self) -> &[u8] { - let nulled = self.as_bytes_with_nul(); - &nulled[..nulled.len() - 1] + #[inline] + pub fn as_bytes(&self) -> BorrowedBytes { + BorrowedBytes::from(self) + } + + /// Get the bytes that make up this string, including the trailing null byte. + pub fn as_bytes_with_nul(&self) -> BorrowedBytes { + let BorrowedBytes { buf, vref, _lua } = BorrowedBytes::from(self); + // Include the trailing null byte (it's always present but excluded by default) + let buf = unsafe { slice::from_raw_parts((*buf).as_ptr(), (*buf).len() + 1) }; + BorrowedBytes { buf, vref, _lua } } - /// Get the bytes that make up this string, including the trailing nul byte. - pub fn as_bytes_with_nul(&self) -> &[u8] { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 1); + // Does not return the terminating null byte + unsafe fn to_slice(&self) -> (&[u8], Lua) { + let lua = self.0.lua.upgrade(); + let slice = { + let rawlua = lua.lock(); + let ref_thread = rawlua.ref_thread(); - lua.push_ref(&self.0); mlua_debug_assert!( - ffi::lua_type(lua.state, -1) == ffi::LUA_TSTRING, + ffi::lua_type(ref_thread, self.0.index) == ffi::LUA_TSTRING, "string ref is not string type" ); - let mut size = 0; // This will not trigger a 'm' error, because the reference is guaranteed to be of // string type - let data = ffi::lua_tolstring(lua.state, -1, &mut size); - - slice::from_raw_parts(data as *const u8, size + 1) - } + let mut size = 0; + let data = ffi::lua_tolstring(ref_thread, self.0.index, &mut size); + slice::from_raw_parts(data as *const u8, size) + }; + (slice, lua) } -} -impl<'lua> AsRef<[u8]> for String<'lua> { - fn as_ref(&self) -> &[u8] { - self.as_bytes() + /// Converts this Lua string to a generic C pointer. + /// + /// There is no way to convert the pointer back to its original value. + /// + /// Typically this function is used only for hashing and debug information. + #[inline] + pub fn to_pointer(&self) -> *const c_void { + // In Lua < 5.4 (excluding Luau), string pointers are NULL + // Use alternative approach + let lua = self.0.lua.lock(); + unsafe { ffi::lua_tostring(lua.ref_thread(), self.0.index) as *const c_void } } } -impl<'lua> Borrow<[u8]> for String<'lua> { - fn borrow(&self) -> &[u8] { - self.as_bytes() +impl fmt::Debug for LuaString { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.as_bytes(); + // Check if the string is valid utf8 + if let Ok(s) = str::from_utf8(&bytes) { + return s.fmt(f); + } + + // Format as bytes + write!(f, "b")?; + ::fmt(bstr::BStr::new(&bytes), f) } } -// Lua strings are basically &[u8] slices, so implement PartialEq for anything resembling that. +// Lua strings are basically `&[u8]` slices, so implement `PartialEq` for anything resembling that. // -// This makes our `String` comparable with `Vec`, `[u8]`, `&str`, `String` and `mlua::String` -// itself. +// This makes our `LuaString` comparable with `Vec`, `[u8]`, `&str` and `String`. // // The only downside is that this disallows a comparison with `Cow`, as that only implements // `AsRef`, which collides with this impl. Requiring `AsRef` would fix that, but limit us // in other ways. -impl<'lua, T> PartialEq for String<'lua> +impl PartialEq for LuaString where - T: AsRef<[u8]>, + T: AsRef<[u8]> + ?Sized, { fn eq(&self, other: &T) -> bool { self.as_bytes() == other.as_ref() } } -impl<'lua> Eq for String<'lua> {} +impl Eq for LuaString {} + +impl PartialOrd for LuaString +where + T: AsRef<[u8]> + ?Sized, +{ + fn partial_cmp(&self, other: &T) -> Option { + <[u8]>::partial_cmp(&self.as_bytes(), other.as_ref()) + } +} + +impl PartialOrd for LuaString { + fn partial_cmp(&self, other: &LuaString) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for LuaString { + fn cmp(&self, other: &LuaString) -> cmp::Ordering { + self.as_bytes().cmp(&other.as_bytes()) + } +} -impl<'lua> Hash for String<'lua> { +impl Hash for LuaString { fn hash(&self, state: &mut H) { self.as_bytes().hash(state); } } -#[cfg(feature = "serialize")] -impl<'lua> Serialize for String<'lua> { +#[cfg(feature = "serde")] +impl Serialize for LuaString { fn serialize(&self, serializer: S) -> StdResult where S: Serializer, { match self.to_str() { - Ok(s) => serializer.serialize_str(s), - Err(_) => serializer.serialize_bytes(self.as_bytes()), + Ok(s) => serializer.serialize_str(&s), + Err(_) => serializer.serialize_bytes(&self.as_bytes()), } } } + +struct Display<'a>(&'a LuaString); + +impl fmt::Display for Display<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.0.as_bytes(); + ::fmt(bstr::BStr::new(&bytes), f) + } +} + +/// A borrowed string (`&str`) that holds a strong reference to the Lua state. +pub struct BorrowedStr { + // `buf` points to a readonly memory managed by Lua + pub(crate) buf: &'static str, + pub(crate) vref: ValueRef, + pub(crate) _lua: Lua, +} + +impl Deref for BorrowedStr { + type Target = str; + + #[inline(always)] + fn deref(&self) -> &str { + self.buf + } +} + +impl Borrow for BorrowedStr { + #[inline(always)] + fn borrow(&self) -> &str { + self.buf + } +} + +impl AsRef for BorrowedStr { + #[inline(always)] + fn as_ref(&self) -> &str { + self.buf + } +} + +impl fmt::Display for BorrowedStr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.buf.fmt(f) + } +} + +impl fmt::Debug for BorrowedStr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.buf.fmt(f) + } +} + +impl PartialEq for BorrowedStr +where + T: AsRef, +{ + fn eq(&self, other: &T) -> bool { + self.buf == other.as_ref() + } +} + +impl Eq for BorrowedStr {} + +impl PartialOrd for BorrowedStr +where + T: AsRef, +{ + fn partial_cmp(&self, other: &T) -> Option { + self.buf.partial_cmp(other.as_ref()) + } +} + +impl Ord for BorrowedStr { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.buf.cmp(other.buf) + } +} + +impl TryFrom<&LuaString> for BorrowedStr { + type Error = Error; + + #[inline] + fn try_from(value: &LuaString) -> Result { + let BorrowedBytes { buf, vref, _lua } = BorrowedBytes::from(value); + let buf = + str::from_utf8(buf).map_err(|e| Error::from_lua_conversion("string", "&str", e.to_string()))?; + Ok(Self { buf, vref, _lua }) + } +} + +/// A borrowed byte slice (`&[u8]`) that holds a strong reference to the Lua state. +pub struct BorrowedBytes { + // `buf` points to a readonly memory managed by Lua + pub(crate) buf: &'static [u8], + pub(crate) vref: ValueRef, + pub(crate) _lua: Lua, +} + +impl Deref for BorrowedBytes { + type Target = [u8]; + + #[inline(always)] + fn deref(&self) -> &[u8] { + self.buf + } +} + +impl Borrow<[u8]> for BorrowedBytes { + #[inline(always)] + fn borrow(&self) -> &[u8] { + self.buf + } +} + +impl AsRef<[u8]> for BorrowedBytes { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.buf + } +} + +impl fmt::Debug for BorrowedBytes { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.buf.fmt(f) + } +} + +impl PartialEq for BorrowedBytes +where + T: AsRef<[u8]>, +{ + fn eq(&self, other: &T) -> bool { + self.buf == other.as_ref() + } +} + +impl Eq for BorrowedBytes {} + +impl PartialOrd for BorrowedBytes +where + T: AsRef<[u8]>, +{ + fn partial_cmp(&self, other: &T) -> Option { + self.buf.partial_cmp(other.as_ref()) + } +} + +impl Ord for BorrowedBytes { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.buf.cmp(other.buf) + } +} + +impl<'a> IntoIterator for &'a BorrowedBytes { + type Item = &'a u8; + type IntoIter = slice::Iter<'a, u8>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl From<&LuaString> for BorrowedBytes { + #[inline] + fn from(value: &LuaString) -> Self { + let (buf, _lua) = unsafe { value.to_slice() }; + let vref = value.0.clone(); + // SAFETY: The `buf` is valid for the lifetime of the Lua state and occupied slot index + let buf = unsafe { mem::transmute::<&[u8], &'static [u8]>(buf) }; + Self { buf, vref, _lua } + } +} + +struct WrappedString>(T); + +impl LuaString { + /// Wraps bytes, returning an opaque type that implements [`IntoLua`] trait. + /// + /// This function uses [`Lua::create_string`] under the hood. + pub fn wrap(data: impl AsRef<[u8]>) -> impl IntoLua { + WrappedString(data) + } +} + +impl> IntoLua for WrappedString { + fn into_lua(self, lua: &Lua) -> Result { + lua.create_string(self.0).map(Value::String) + } +} + +impl LuaType for LuaString { + const TYPE_ID: c_int = ffi::LUA_TSTRING; +} + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(LuaString: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(LuaString: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(BorrowedBytes: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(BorrowedStr: Send, Sync); +} diff --git a/src/table.rs b/src/table.rs index 05e24eba..e938000e 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,28 +1,184 @@ +//! Lua table handling. +//! +//! Tables are Lua's primary data structure, used for arrays, dictionaries, objects, modules, +//! and more. This module provides types for creating and manipulating Lua tables from Rust. +//! +//! # Basic Operations +//! +//! Tables support key-value access similar to Rust's `HashMap`: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let table = lua.create_table()?; +//! +//! // Set and get values +//! table.set("key", "value")?; +//! let value: String = table.get("key")?; +//! assert_eq!(value, "value"); +//! +//! // Keys and values can be any Lua-compatible type +//! table.set(1, "first")?; +//! table.set("nested", lua.create_table()?)?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Array Operations +//! +//! Tables can be used as arrays with 1-based indexing: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let array = lua.create_table()?; +//! +//! // Push values to the end (like Vec::push) +//! array.push("first")?; +//! array.push("second")?; +//! array.push("third")?; +//! +//! // Pop from the end +//! let last: String = array.pop()?; +//! assert_eq!(last, "third"); +//! +//! // Get length +//! assert_eq!(array.raw_len(), 2); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Iteration +//! +//! Iterate over all key-value pairs with [`Table::pairs`]: +//! +//! ``` +//! # use mlua::{Lua, Result, Value}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let table = lua.create_table()?; +//! table.set("a", 1)?; +//! table.set("b", 2)?; +//! +//! for pair in table.pairs::() { +//! let (key, value) = pair?; +//! println!("{key} = {value}"); +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! For array portions, use [`Table::sequence_values`]: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let array = lua.create_sequence_from(["a", "b", "c"])?; +//! +//! for value in array.sequence_values::() { +//! println!("{}", value?); +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! # Raw vs Normal Access +//! +//! Methods prefixed with `raw_` (like [`Table::raw_get`], [`Table::raw_set`]) bypass +//! metamethods, directly accessing the table's contents. Normal methods may trigger +//! `__index`, `__newindex`, and other metamethods: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! +//! // raw_set bypasses __newindex metamethod +//! let t = lua.create_table()?; +//! t.raw_set("key", "value")?; +//! +//! // raw_get bypasses __index metamethod +//! let v: String = t.raw_get("key")?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Metatables +//! +//! Tables can have metatables that customize their behavior: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! +//! let table = lua.create_table()?; +//! let metatable = lua.create_table()?; +//! +//! // Set a default value via __index +//! metatable.set("__index", lua.create_function(|_, _: ()| Ok("default"))?)?; +//! table.set_metatable(Some(metatable))?; +//! +//! // Accessing missing keys returns "default" +//! let value: String = table.get("missing")?; +//! assert_eq!(value, "default"); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Global Table +//! +//! The Lua global environment is itself a table, accessible via [`Lua::globals`]: +//! +//! ``` +//! # use mlua::{Lua, Result}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let globals = lua.globals(); +//! +//! // Set a global variable +//! globals.set("my_var", 42)?; +//! +//! // Now accessible from Lua code +//! let result: i32 = lua.load("my_var + 8").eval()?; +//! assert_eq!(result, 50); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`Lua::globals`]: crate::Lua::globals + +use std::collections::HashSet; +use std::fmt; use std::marker::PhantomData; - -#[cfg(feature = "serialize")] -use { - rustc_hash::FxHashSet, - serde::ser::{self, Serialize, SerializeMap, SerializeSeq, Serializer}, - std::{cell::RefCell, os::raw::c_void, result::Result as StdResult}, -}; +use std::os::raw::c_void; use crate::error::{Error, Result}; -use crate::ffi; use crate::function::Function; -use crate::types::{Integer, LuaRef}; -use crate::util::{assert_stack, check_stack, StackGuard}; -use crate::value::{FromLua, FromLuaMulti, Nil, ToLua, ToLuaMulti, Value}; +use crate::state::{LuaGuard, RawLua, WeakLua}; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, ObjectLike}; +use crate::types::{Integer, ValueRef}; +use crate::util::{StackGuard, assert_stack, check_stack, get_metatable_ptr}; +use crate::value::{Nil, Value}; #[cfg(feature = "async")] -use {futures_core::future::LocalBoxFuture, futures_util::future}; +use crate::function::AsyncCallFuture; + +#[cfg(feature = "serde")] +use { + rustc_hash::FxHashSet, + serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer}, + std::{cell::RefCell, rc::Rc, result::Result as StdResult}, +}; /// Handle to an internal Lua table. -#[derive(Clone, Debug)] -pub struct Table<'lua>(pub(crate) LuaRef<'lua>); +#[derive(Clone, PartialEq)] +pub struct Table(pub(crate) ValueRef); -#[allow(clippy::len_without_is_empty)] -impl<'lua> Table<'lua> { +impl Table { /// Sets a key-value pair in the table. /// /// If the value is `nil`, this will effectively remove the pair. @@ -55,20 +211,27 @@ impl<'lua> Table<'lua> { /// # } /// ``` /// - /// [`raw_set`]: #method.raw_set - pub fn set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { - let lua = self.0.lua; - let key = key.to_lua(lua)?; - let value = value.to_lua(lua)?; + /// [`raw_set`]: Table::raw_set + pub fn set(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()> { + // Fast track (skip protected call) + if !self.has_metatable() { + return self.raw_set(key, value); + } + + self.set_protected(key, value) + } + pub(crate) fn set_protected(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_ref(&self.0); - lua.push_value(key)?; - lua.push_value(value)?; - protect_lua!(lua.state, 3, 0, fn(state) ffi::lua_settable(state, -3)) + key.push_into_stack(&lua)?; + value.push_into_stack(&lua)?; + protect_lua!(state, 3, 0, fn(state) ffi::lua_settable(state, -3)) } } @@ -95,37 +258,86 @@ impl<'lua> Table<'lua> { /// # } /// ``` /// - /// [`raw_get`]: #method.raw_get - pub fn get, V: FromLua<'lua>>(&self, key: K) -> Result { - let lua = self.0.lua; - let key = key.to_lua(lua)?; + /// [`raw_get`]: Table::raw_get + pub fn get(&self, key: impl IntoLua) -> Result { + // Fast track (skip protected call) + if !self.has_metatable() { + return self.raw_get(key); + } - let value = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + self.get_protected(key) + } + + pub(crate) fn get_protected(&self, key: impl IntoLua) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_ref(&self.0); - lua.push_value(key)?; - protect_lua!(lua.state, 2, 1, fn(state) ffi::lua_gettable(state, -2))?; + key.push_into_stack(&lua)?; + protect_lua!(state, 2, 1, fn(state) ffi::lua_gettable(state, -2))?; - lua.pop_value() - }; - V::from_lua(value, lua) + V::from_stack(-1, &lua) + } } /// Checks whether the table contains a non-nil value for `key`. - pub fn contains_key>(&self, key: K) -> Result { - let lua = self.0.lua; - let key = key.to_lua(lua)?; + /// + /// This might invoke the `__index` metamethod. + pub fn contains_key(&self, key: impl IntoLua) -> Result { + Ok(self.get::(key)? != Value::Nil) + } + /// Appends a value to the back of the table. + /// + /// This might invoke the `__len` and `__newindex` metamethods. + pub fn push(&self, value: impl IntoLua) -> Result<()> { + // Fast track (skip protected call) + if !self.has_metatable() { + return self.raw_push(value); + } + + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_ref(&self.0); - lua.push_value(key)?; - protect_lua!(lua.state, 2, 1, fn(state) ffi::lua_gettable(state, -2))?; - Ok(ffi::lua_isnil(lua.state, -1) == 0) + value.push_into_stack(&lua)?; + protect_lua!(state, 2, 0, fn(state) { + let len = ffi::luaL_len(state, -2) as Integer; + ffi::lua_seti(state, -2, len + 1); + })? + } + Ok(()) + } + + /// Removes the last element from the table and returns it. + /// + /// This might invoke the `__len` and `__newindex` metamethods. + pub fn pop(&self) -> Result { + // Fast track (skip protected call) + if !self.has_metatable() { + return self.raw_pop(); + } + + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + lua.push_ref(&self.0); + protect_lua!(state, 1, 1, fn(state) { + let len = ffi::luaL_len(state, -1) as Integer; + ffi::lua_geti(state, -1, len); + ffi::lua_pushnil(state); + ffi::lua_seti(state, -3, len); + })?; + V::from_stack(-1, &lua) } } @@ -151,92 +363,95 @@ impl<'lua> Table<'lua> { /// /// let always_equals_mt = lua.create_table()?; /// always_equals_mt.set("__eq", lua.create_function(|_, (_t1, _t2): (Table, Table)| Ok(true))?)?; - /// table2.set_metatable(Some(always_equals_mt)); + /// table2.set_metatable(Some(always_equals_mt))?; /// /// assert!(table1.equals(&table1.clone())?); /// assert!(table1.equals(&table2)?); /// # Ok(()) /// # } /// ``` - pub fn equals>(&self, other: T) -> Result { - let other = other.as_ref(); + pub fn equals(&self, other: &Self) -> Result { if self == other { return Ok(true); } - // Compare using __eq metamethod if exists + // Compare using `__eq` metamethod if exists // First, check the self for the metamethod. // If self does not define it, then check the other table. - if let Some(mt) = self.get_metatable() { - if mt.contains_key("__eq")? { - return mt - .get::<_, Function>("__eq")? - .call((self.clone(), other.clone())); - } + if let Some(mt) = self.metatable() + && let Some(eq_func) = mt.get::>("__eq")? + { + return eq_func.call((self, other)); } - if let Some(mt) = other.get_metatable() { - if mt.contains_key("__eq")? { - return mt - .get::<_, Function>("__eq")? - .call((self.clone(), other.clone())); - } + if let Some(mt) = other.metatable() + && let Some(eq_func) = mt.get::>("__eq")? + { + return eq_func.call((self, other)); } Ok(false) } /// Sets a key-value pair without invoking metamethods. - pub fn raw_set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { - let lua = self.0.lua; - let key = key.to_lua(lua)?; - let value = value.to_lua(lua)?; - + pub fn raw_set(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; + #[cfg(feature = "luau")] + self.check_readonly_write(&lua)?; + + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_ref(&self.0); - lua.push_value(key)?; - lua.push_value(value)?; - protect_lua!(lua.state, 3, 0, fn(state) ffi::lua_rawset(state, -3)) + key.push_into_stack(&lua)?; + value.push_into_stack(&lua)?; + + if lua.unlikely_memory_error() { + ffi::lua_rawset(state, -3); + ffi::lua_pop(state, 1); + Ok(()) + } else { + protect_lua!(state, 3, 0, fn(state) ffi::lua_rawset(state, -3)) + } } } /// Gets the value associated to `key` without invoking metamethods. - pub fn raw_get, V: FromLua<'lua>>(&self, key: K) -> Result { - let lua = self.0.lua; - let key = key.to_lua(lua)?; - - let value = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 3)?; + pub fn raw_get(&self, key: impl IntoLua) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 3)?; lua.push_ref(&self.0); - lua.push_value(key)?; - ffi::lua_rawget(lua.state, -2); + key.push_into_stack(&lua)?; + ffi::lua_rawget(state, -2); - lua.pop_value() - }; - V::from_lua(value, lua) + V::from_stack(-1, &lua) + } } - /// Inserts element value at position `idx` to the table, shifting up the elements from `table[idx]`. + /// Inserts element value at position `idx` to the table, shifting up the elements from + /// `table[idx]`. + /// /// The worst case complexity is O(n), where n is the table length. - pub fn raw_insert>(&self, idx: Integer, value: V) -> Result<()> { - let lua = self.0.lua; - let size = self.raw_len(); + pub fn raw_insert(&self, idx: Integer, value: impl IntoLua) -> Result<()> { + let size = self.raw_len() as Integer; if idx < 1 || idx > size + 1 { - return Err(Error::RuntimeError("index out of bounds".to_string())); + return Err(Error::runtime("index out of bounds")); } - let value = value.to_lua(lua)?; + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_ref(&self.0); - lua.push_value(value)?; - protect_lua!(lua.state, 2, 0, |state| { + value.push_into_stack(&lua)?; + protect_lua!(state, 2, 0, |state| { for i in (idx..=size).rev() { // table[i+1] = table[i] ffi::lua_rawgeti(state, -2, i); @@ -247,28 +462,79 @@ impl<'lua> Table<'lua> { } } + /// Appends a value to the back of the table without invoking metamethods. + pub fn raw_push(&self, value: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + #[cfg(feature = "luau")] + self.check_readonly_write(&lua)?; + + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + lua.push_ref(&self.0); + value.push_into_stack(&lua)?; + + unsafe fn callback(state: *mut ffi::lua_State) { + let len = ffi::lua_rawlen(state, -2) as Integer; + ffi::lua_rawseti(state, -2, len + 1); + } + + if lua.unlikely_memory_error() { + callback(state); + } else { + protect_lua!(state, 2, 0, fn(state) callback(state))?; + } + } + Ok(()) + } + + /// Removes the last element from the table and returns it, without invoking metamethods. + pub fn raw_pop(&self) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + #[cfg(feature = "luau")] + self.check_readonly_write(&lua)?; + + let _sg = StackGuard::new(state); + check_stack(state, 3)?; + + lua.push_ref(&self.0); + let len = ffi::lua_rawlen(state, -1) as Integer; + ffi::lua_rawgeti(state, -1, len); + // Set slot to nil (it must be safe to do) + ffi::lua_pushnil(state); + ffi::lua_rawseti(state, -3, len); + + V::from_stack(-1, &lua) + } + } + /// Removes a key from the table. /// /// If `key` is an integer, mlua shifts down the elements from `table[key+1]`, - /// and erases element `table[key]`. The complexity is O(n) in the worst case, - /// where n is the table length. + /// and erases element `table[key]`. The complexity is `O(n)` in the worst case, + /// where `n` is the table length. /// /// For other key types this is equivalent to setting `table[key] = nil`. - pub fn raw_remove>(&self, key: K) -> Result<()> { - let lua = self.0.lua; - let key = key.to_lua(lua)?; + pub fn raw_remove(&self, key: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); + let key = key.into_lua(lua.lua())?; match key { Value::Integer(idx) => { - let size = self.raw_len(); + let size = self.raw_len() as Integer; if idx < 1 || idx > size { - return Err(Error::RuntimeError("index out of bounds".to_string())); + return Err(Error::runtime("index out of bounds")); } unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_ref(&self.0); - protect_lua!(lua.state, 1, 0, |state| { + protect_lua!(state, 1, 0, |state| { for i in idx..size { ffi::lua_rawgeti(state, -1, i + 1); ffi::lua_rawseti(state, -2, i); @@ -282,48 +548,96 @@ impl<'lua> Table<'lua> { } } - /// Returns the result of the Lua `#` operator. + /// Clears the table, removing all keys and values from array and hash parts, + /// without invoking metamethods. /// - /// This might invoke the `__len` metamethod. Use the [`raw_len`] method if that is not desired. + /// This method is useful to clear the table while keeping its capacity. + pub fn clear(&self) -> Result<()> { + let lua = self.0.lua.lock(); + unsafe { + #[cfg(feature = "luau")] + { + self.check_readonly_write(&lua)?; + ffi::lua_cleartable(lua.ref_thread(), self.0.index); + } + + #[cfg(not(feature = "luau"))] + { + let state = lua.state(); + check_stack(state, 4)?; + + lua.push_ref(&self.0); + + // This is safe as long as we don't assign new keys + ffi::lua_pushnil(state); + while ffi::lua_next(state, -2) != 0 { + ffi::lua_pop(state, 1); // pop value + ffi::lua_pushvalue(state, -1); // copy key + ffi::lua_pushnil(state); + ffi::lua_rawset(state, -4); + } + } + } + + Ok(()) + } + + /// Returns the result of the Lua `#` operator. /// - /// [`raw_len`]: #method.raw_len + /// This might invoke the `__len` metamethod. Use the [`Table::raw_len`] method if that is not + /// desired. pub fn len(&self) -> Result { - let lua = self.0.lua; + // Fast track (skip protected call) + if !self.has_metatable() { + return Ok(self.raw_len() as Integer); + } + + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_ref(&self.0); - protect_lua!(lua.state, 1, 0, |state| ffi::luaL_len(state, -1)) + protect_lua!(state, 1, 0, |state| ffi::luaL_len(state, -1)) } } /// Returns the result of the Lua `#` operator, without invoking the `__len` metamethod. - pub fn raw_len(&self) -> Integer { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 1); + pub fn raw_len(&self) -> usize { + let lua = self.0.lua.lock(); + unsafe { ffi::lua_rawlen(lua.ref_thread(), self.0.index) } + } - lua.push_ref(&self.0); - ffi::lua_rawlen(lua.state, -1) as Integer + /// Returns `true` if the table is empty, without invoking metamethods. + /// + /// It checks both the array part and the hash part. + pub fn is_empty(&self) -> bool { + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); + unsafe { + ffi::lua_pushnil(ref_thread); + if ffi::lua_next(ref_thread, self.0.index) == 0 { + return true; + } + ffi::lua_pop(ref_thread, 2); } + false } /// Returns a reference to the metatable of this table, or `None` if no metatable is set. /// - /// Unlike the `getmetatable` Lua function, this method ignores the `__metatable` field. - pub fn get_metatable(&self) -> Option> { - let lua = self.0.lua; + /// Unlike the [`getmetatable`] Lua function, this method ignores the `__metatable` field. + /// + /// [`getmetatable`]: https://www.lua.org/manual/5.4/manual.html#pdf-getmetatable + pub fn metatable(&self) -> Option
{ + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 2); - - lua.push_ref(&self.0); - if ffi::lua_getmetatable(lua.state, -1) == 0 { + if ffi::lua_getmetatable(ref_thread, self.0.index) == 0 { None } else { - Some(Table(lua.pop_ref())) + Some(Table(lua.pop_ref_thread())) } } } @@ -332,61 +646,89 @@ impl<'lua> Table<'lua> { /// /// If `metatable` is `None`, the metatable is removed (if no metatable is set, this does /// nothing). - pub fn set_metatable(&self, metatable: Option>) { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 2); + pub fn set_metatable(&self, metatable: Option
) -> Result<()> { + #[cfg(feature = "luau")] + if self.is_readonly() { + return Err(Error::runtime("attempt to modify a readonly table")); + } - lua.push_ref(&self.0); - if let Some(metatable) = metatable { - lua.push_ref(&metatable.0); + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); + unsafe { + if let Some(metatable) = &metatable { + ffi::lua_pushvalue(ref_thread, metatable.0.index); } else { - ffi::lua_pushnil(lua.state); + ffi::lua_pushnil(ref_thread); } - ffi::lua_setmetatable(lua.state, -2); + ffi::lua_setmetatable(ref_thread, self.0.index); } + Ok(()) + } + + /// Returns true if the table has metatable attached. + #[doc(hidden)] + #[inline] + pub fn has_metatable(&self) -> bool { + let lua = self.0.lua.lock(); + unsafe { !get_metatable_ptr(lua.ref_thread(), self.0.index).is_null() } } /// Sets `readonly` attribute on the table. - /// - /// Requires `feature = "luau"` #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn set_readonly(&self, enabled: bool) { - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); unsafe { - lua.ref_thread_exec(|refthr| { - ffi::lua_setreadonly(refthr, self.0.index, enabled as _); - if !enabled { - // Reset "safeenv" flag - ffi::lua_setsafeenv(refthr, self.0.index, 0); - } - }); + ffi::lua_setreadonly(ref_thread, self.0.index, enabled as _); + if !enabled { + // Reset "safeenv" flag + ffi::lua_setsafeenv(ref_thread, self.0.index, 0); + } } } /// Returns `readonly` attribute of the table. - /// - /// Requires `feature = "luau"` #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub fn is_readonly(&self) -> bool { - let lua = self.0.lua; - unsafe { lua.ref_thread_exec(|refthr| ffi::lua_getreadonly(refthr, self.0.index) != 0) } + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); + unsafe { ffi::lua_getreadonly(ref_thread, self.0.index) != 0 } } - /// Consume this table and return an iterator over the pairs of the table. + /// Controls `safeenv` attribute on the table. /// - /// This works like the Lua `pairs` function, but does not invoke the `__pairs` metamethod. + /// This a special flag that activates some performance optimizations for environment tables. + /// In particular, it controls: + /// - Optimization of import resolution (cache values of constant keys). + /// - Fast-path for built-in iteration with pairs/ipairs. + /// - Fast-path for some built-in functions (fastcall). /// - /// The pairs are wrapped in a [`Result`], since they are lazily converted to `K` and `V` types. + /// For `safeenv` environments, monkey patching or modifying values may not work as expected. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_safeenv(&self, enabled: bool) { + let lua = self.0.lua.lock(); + unsafe { ffi::lua_setsafeenv(lua.ref_thread(), self.0.index, enabled as _) }; + } + + /// Converts this table to a generic C pointer. /// - /// # Note + /// Different tables will give different pointers. + /// There is no way to convert the pointer back to its original value. /// - /// While this method consumes the `Table` object, it can not prevent code from mutating the - /// table while the iteration is in progress. Refer to the [Lua manual] for information about - /// the consequences of such mutation. + /// Typically this function is used only for hashing and debug information. + #[inline] + pub fn to_pointer(&self) -> *const c_void { + self.0.to_pointer() + } + + /// Returns an iterator over the pairs of the table. + /// + /// This works like the Lua `pairs` function, but does not invoke the `__pairs` metamethod. + /// + /// The pairs are wrapped in a [`Result`], since they are lazily converted to `K` and `V` types. /// /// # Examples /// @@ -407,30 +749,47 @@ impl<'lua> Table<'lua> { /// # } /// ``` /// - /// [`Result`]: crate::Result /// [Lua manual]: http://www.lua.org/manual/5.4/manual.html#pdf-next - pub fn pairs, V: FromLua<'lua>>(self) -> TablePairs<'lua, K, V> { + pub fn pairs(&self) -> TablePairs<'_, K, V> { TablePairs { - table: self.0, + guard: self.0.lua.lock(), + table: self, key: Some(Nil), _phantom: PhantomData, } } - /// Consume this table and return an iterator over all values in the sequence part of the table. - /// - /// The iterator will yield all values `t[1]`, `t[2]`, and so on, until a `nil` value is - /// encountered. This mirrors the behavior of Lua's `ipairs` function and will invoke the - /// `__index` metamethod according to the usual rules. However, the deprecated `__ipairs` - /// metatable will not be called. + /// Iterates over the pairs of the table, invoking the given closure on each pair. /// - /// Just like [`pairs`], the values are wrapped in a [`Result`]. - /// - /// # Note + /// This method is similar to [`Table::pairs`], but optimized for performance. + /// It does not invoke the `__pairs` metamethod. + pub fn for_each(&self, mut f: impl FnMut(K, V) -> Result<()>) -> Result<()> + where + K: FromLua, + V: FromLua, + { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 5)?; + + lua.push_ref(&self.0); + ffi::lua_pushnil(state); + while ffi::lua_next(state, -2) != 0 { + let k = K::from_stack(-2, &lua)?; + let v = lua.pop::()?; + f(k, v)?; + } + } + Ok(()) + } + + /// Returns an iterator over all values in the sequence part of the table. /// - /// While this method consumes the `Table` object, it can not prevent code from mutating the - /// table while the iteration is in progress. Refer to the [Lua manual] for information about - /// the consequences of such mutation. + /// The iterator will yield all values `t[1]`, `t[2]` and so on, until a `nil` value is + /// encountered. This mirrors the behavior of Lua's `ipairs` function but does not invoke + /// any metamethods. /// /// # Examples /// @@ -454,277 +813,482 @@ impl<'lua> Table<'lua> { /// # Ok(()) /// # } /// ``` - /// - /// [`pairs`]: #method.pairs - /// [`Result`]: crate::Result - /// [Lua manual]: http://www.lua.org/manual/5.4/manual.html#pdf-next - pub fn sequence_values>(self) -> TableSequence<'lua, V> { + pub fn sequence_values(&self) -> TableSequence<'_, V> { TableSequence { - table: self.0, - index: Some(1), + guard: self.0.lua.lock(), + table: self, + index: 1, len: None, - raw: false, _phantom: PhantomData, } } - /// Consume this table and return an iterator over all values in the sequence part of the table. - /// - /// Unlike the `sequence_values`, does not invoke `__index` metamethod when iterating. + /// Iterates over the sequence part of the table, invoking the given closure on each value. /// - /// [`sequence_values`]: #method.sequence_values - pub fn raw_sequence_values>(self) -> TableSequence<'lua, V> { - TableSequence { - table: self.0, - index: Some(1), - len: None, - raw: true, - _phantom: PhantomData, - } + /// This methods is similar to [`Table::sequence_values`], but optimized for performance. + #[doc(hidden)] + pub fn for_each_value(&self, f: impl FnMut(V) -> Result<()>) -> Result<()> { + self.for_each_value_by_len(None, f) } - #[cfg(any(feature = "serialize"))] - pub(crate) fn raw_sequence_values_by_len>( - self, - len: Option, - ) -> TableSequence<'lua, V> { - let len = len.unwrap_or_else(|| self.raw_len()); - TableSequence { - table: self.0, - index: Some(1), - len: Some(len), - raw: true, - _phantom: PhantomData, + fn for_each_value_by_len( + &self, + len: impl Into>, + mut f: impl FnMut(V) -> Result<()>, + ) -> Result<()> { + let len = len.into(); + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + check_stack(state, 4)?; + + lua.push_ref(&self.0); + for i in 1.. { + if len.map(|len| i > len).unwrap_or(false) { + break; + } + let t = ffi::lua_rawgeti(state, -1, i as _); + if len.is_none() && t == ffi::LUA_TNIL { + break; + } + f(lua.pop::()?)?; + } } + Ok(()) } - #[cfg(feature = "serialize")] - pub(crate) fn is_array(&self) -> bool { - let lua = self.0.lua; + /// Sets element value at position `idx` without invoking metamethods. + #[doc(hidden)] + pub fn raw_seti(&self, idx: usize, value: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 3); + #[cfg(feature = "luau")] + self.check_readonly_write(&lua)?; + + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_ref(&self.0); - if ffi::lua_getmetatable(lua.state, -1) == 0 { - return false; + value.push_into_stack(&lua)?; + + let idx = idx.try_into().unwrap(); + if lua.unlikely_memory_error() { + ffi::lua_rawseti(state, -2, idx); + } else { + protect_lua!(state, 2, 0, |state| ffi::lua_rawseti(state, -2, idx))?; } - crate::serde::push_array_metatable(lua.state); - ffi::lua_rawequal(lua.state, -1, -2) != 0 } + Ok(()) } -} -impl<'lua> PartialEq for Table<'lua> { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} + /// Checks if the table has the array metatable attached. + #[cfg(feature = "serde")] + fn has_array_metatable(&self) -> bool { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 3); -impl<'lua> AsRef> for Table<'lua> { - #[inline] - fn as_ref(&self) -> &Self { - self + lua.push_ref(&self.0); + if ffi::lua_getmetatable(state, -1) == 0 { + return false; + } + crate::serde::push_array_metatable(state); + ffi::lua_rawequal(state, -1, -2) != 0 + } } -} -/// An extension trait for `Table`s that provides a variety of convenient functionality. -pub trait TableExt<'lua> { - /// Calls the table as function assuming it has `__call` metamethod. + /// If the table is an array, returns the number of non-nil elements and max index. /// - /// The metamethod is called with the table as its first argument, followed by the passed arguments. - fn call(&self, args: A) -> Result - where - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>; - - /// Asynchronously calls the table as function assuming it has `__call` metamethod. + /// Returns `None` if the table is not an array. /// - /// The metamethod is called with the table as its first argument, followed by the passed arguments. - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn call_async<'fut, A, R>(&self, args: A) -> LocalBoxFuture<'fut, Result> - where - 'lua: 'fut, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut; + /// This operation has O(n) complexity. + #[cfg(feature = "serde")] + fn find_array_len(&self) -> Option<(usize, usize)> { + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); + unsafe { + let _sg = StackGuard::new(ref_thread); - /// Gets the function associated to `key` from the table and executes it, - /// passing the table itself along with `args` as function arguments. - /// - /// This is a shortcut for - /// `table.get::<_, Function>(key)?.call((table.clone(), arg1, ..., argN))` - /// - /// This might invoke the `__index` metamethod. - fn call_method(&self, key: K, args: A) -> Result - where - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>; + let (mut count, mut max_index) = (0, 0); + ffi::lua_pushnil(ref_thread); + while ffi::lua_next(ref_thread, self.0.index) != 0 { + if ffi::lua_type(ref_thread, -2) != ffi::LUA_TNUMBER { + return None; + } - /// Gets the function associated to `key` from the table and executes it, - /// passing `args` as function arguments. - /// - /// This is a shortcut for - /// `table.get::<_, Function>(key)?.call(args)` - /// - /// This might invoke the `__index` metamethod. - fn call_function(&self, key: K, args: A) -> Result - where - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>; + let k = ffi::lua_tonumber(ref_thread, -2); + if k.trunc() != k || k < 1.0 { + return None; + } + max_index = std::cmp::max(max_index, k as usize); + count += 1; + ffi::lua_pop(ref_thread, 1); + } + Some((count, max_index)) + } + } - /// Gets the function associated to `key` from the table and asynchronously executes it, - /// passing the table itself along with `args` as function arguments and returning Future. - /// - /// Requires `feature = "async"` + /// Determines if the table should be encoded as an array or a map. /// - /// This might invoke the `__index` metamethod. - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn call_async_method<'fut, K, A, R>(&self, key: K, args: A) -> LocalBoxFuture<'fut, Result> - where - 'lua: 'fut, - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut; - - /// Gets the function associated to `key` from the table and asynchronously executes it, - /// passing `args` as function arguments and returning Future. + /// The algorithm is the following: + /// 1. If `detect_mixed_tables` is enabled, iterate over all keys in the table checking is they + /// all are positive integers. If non-array key is found, return `None` (encode as map). + /// Otherwise check the sparsity of the array. Too sparse arrays are encoded as maps. /// - /// Requires `feature = "async"` + /// 2. If `detect_mixed_tables` is disabled, check if the table has a positive length or has the + /// array metatable. If so, encode as array. If the table is empty and + /// `encode_empty_tables_as_array` is enabled, encode as array. /// - /// This might invoke the `__index` metamethod. - #[cfg(feature = "async")] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn call_async_function<'fut, K, A, R>( + /// Returns the length of the array if it should be encoded as an array. + #[cfg(feature = "serde")] + pub(crate) fn encode_as_array(&self, options: crate::serde::de::Options) -> Option { + if options.detect_mixed_tables { + if let Some((len, max_idx)) = self.find_array_len() { + // If the array is too sparse, serialize it as a map instead + if len < 10 || len * 2 >= max_idx { + return Some(max_idx); + } + } + } else { + let len = self.raw_len(); + if len > 0 || self.has_array_metatable() { + return Some(len); + } + if options.encode_empty_tables_as_array && self.is_empty() { + return Some(0); + } + } + None + } + + #[cfg(feature = "luau")] + #[inline(always)] + fn check_readonly_write(&self, lua: &RawLua) -> Result<()> { + if unsafe { ffi::lua_getreadonly(lua.ref_thread(), self.0.index) != 0 } { + return Err(Error::runtime("attempt to modify a readonly table")); + } + Ok(()) + } + + pub(crate) fn fmt_pretty( &self, - key: K, - args: A, - ) -> LocalBoxFuture<'fut, Result> - where - 'lua: 'fut, - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut; + fmt: &mut fmt::Formatter, + ident: usize, + visited: &mut HashSet<*const c_void>, + ) -> fmt::Result { + visited.insert(self.to_pointer()); + + // Collect key/value pairs into a vector so we can sort them + let mut pairs = self.pairs::().flatten().collect::>(); + // Sort keys + pairs.sort_by(|(a, _), (b, _)| a.sort_cmp(b)); + let is_sequence = (pairs.iter().enumerate()) + .all(|(i, (k, _))| matches!(k, Value::Integer(n) if *n == (i + 1) as Integer)); + if pairs.is_empty() { + return write!(fmt, "{{}}"); + } + writeln!(fmt, "{{")?; + if is_sequence { + // Format as list + for (_, value) in pairs { + write!(fmt, "{}", " ".repeat(ident + 2))?; + value.fmt_pretty(fmt, true, ident + 2, visited)?; + writeln!(fmt, ",")?; + } + } else { + fn is_simple_key(key: &[u8]) -> bool { + key.iter().take(1).all(|c| c.is_ascii_alphabetic() || *c == b'_') + && key.iter().all(|c| c.is_ascii_alphanumeric() || *c == b'_') + } + + for (key, value) in pairs { + match key { + Value::String(key) if is_simple_key(&key.as_bytes()) => { + write!(fmt, "{}{}", " ".repeat(ident + 2), key.display())?; + write!(fmt, " = ")?; + } + _ => { + write!(fmt, "{}[", " ".repeat(ident + 2))?; + key.fmt_pretty(fmt, false, ident + 2, visited)?; + write!(fmt, "] = ")?; + } + } + value.fmt_pretty(fmt, true, ident + 2, visited)?; + writeln!(fmt, ",")?; + } + } + write!(fmt, "{}}}", " ".repeat(ident)) + } +} + +impl fmt::Debug for Table { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + if fmt.alternate() { + return self.fmt_pretty(fmt, 0, &mut HashSet::new()); + } + fmt.debug_tuple("Table").field(&self.0).finish() + } +} + +impl PartialEq<[T]> for Table +where + T: IntoLua + Clone, +{ + fn eq(&self, other: &[T]) -> bool { + let lua = self.0.lua.lock(); + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + assert_stack(state, 4); + + lua.push_ref(&self.0); + + let len = ffi::lua_rawlen(state, -1); + for i in 0..len { + ffi::lua_rawgeti(state, -1, (i + 1) as _); + let val = lua.pop_value(); + if val == Nil { + return i == other.len(); + } + match other.get(i).map(|v| v.clone().into_lua(lua.lua())) { + Some(Ok(other_val)) if val == other_val => continue, + _ => return false, + } + } + } + true + } } -impl<'lua> TableExt<'lua> for Table<'lua> { - fn call(&self, args: A) -> Result +impl PartialEq<&[T]> for Table +where + T: IntoLua + Clone, +{ + #[inline] + fn eq(&self, other: &&[T]) -> bool { + self == *other + } +} + +impl PartialEq<[T; N]> for Table +where + T: IntoLua + Clone, +{ + #[inline] + fn eq(&self, other: &[T; N]) -> bool { + self == &other[..] + } +} + +impl ObjectLike for Table { + #[inline] + fn get(&self, key: impl IntoLua) -> Result { + self.get(key) + } + + #[inline] + fn set(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()> { + self.set(key, value) + } + + #[inline] + fn call(&self, args: impl IntoLuaMulti) -> Result where - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, + R: FromLuaMulti, { // Convert table to a function and call via pcall that respects the `__call` metamethod. Function(self.0.clone()).call(args) } #[cfg(feature = "async")] - fn call_async<'fut, A, R>(&self, args: A) -> LocalBoxFuture<'fut, Result> + #[inline] + fn call_async(&self, args: impl IntoLuaMulti) -> AsyncCallFuture where - 'lua: 'fut, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut, + R: FromLuaMulti, { Function(self.0.clone()).call_async(args) } - fn call_method(&self, key: K, args: A) -> Result + #[inline] + fn call_method(&self, name: &str, args: impl IntoLuaMulti) -> Result where - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, + R: FromLuaMulti, { - let lua = self.0.lua; - let mut args = args.to_lua_multi(lua)?; - args.push_front(Value::Table(self.clone())); - self.get::<_, Function>(key)?.call(args) + self.call_function(name, (self, args)) } - fn call_function(&self, key: K, args: A) -> Result + #[cfg(feature = "async")] + fn call_async_method(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture where - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, + R: FromLuaMulti, { - self.get::<_, Function>(key)?.call(args) + self.call_async_function(name, (self, args)) } - #[cfg(feature = "async")] - fn call_async_method<'fut, K, A, R>(&self, key: K, args: A) -> LocalBoxFuture<'fut, Result> - where - 'lua: 'fut, - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut, - { - let lua = self.0.lua; - let mut args = match args.to_lua_multi(lua) { - Ok(args) => args, - Err(e) => return Box::pin(future::err(e)), - }; - args.push_front(Value::Table(self.clone())); - self.call_async_function(key, args) + #[inline] + fn call_function(&self, name: &str, args: impl IntoLuaMulti) -> Result { + match self.get(name)? { + Value::Function(func) => func.call(args), + val => { + let msg = format!("attempt to call a {} value (function '{name}')", val.type_name()); + Err(Error::runtime(msg)) + } + } } #[cfg(feature = "async")] - fn call_async_function<'fut, K, A, R>(&self, key: K, args: A) -> LocalBoxFuture<'fut, Result> + #[inline] + fn call_async_function(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture where - 'lua: 'fut, - K: ToLua<'lua>, - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua> + 'fut, + R: FromLuaMulti, { - match self.get::<_, Function>(key) { - Ok(func) => func.call_async(args), - Err(e) => Box::pin(future::err(e)), + match self.get(name) { + Ok(Value::Function(func)) => func.call_async(args), + Ok(val) => { + let msg = format!("attempt to call a {} value (function '{name}')", val.type_name()); + AsyncCallFuture::error(Error::RuntimeError(msg)) + } + Err(err) => AsyncCallFuture::error(err), + } + } + + #[inline] + fn to_string(&self) -> Result { + Value::Table(Table(self.0.clone())).to_string() + } + + #[inline] + fn to_value(&self) -> Value { + Value::Table(self.clone()) + } + + #[inline] + fn weak_lua(&self) -> &WeakLua { + &self.0.lua + } +} + +/// A wrapped [`Table`] with customized serialization behavior. +#[cfg(feature = "serde")] +pub(crate) struct SerializableTable<'a> { + table: &'a Table, + options: crate::serde::de::Options, + visited: Rc>>, +} + +#[cfg(feature = "serde")] +impl Serialize for Table { + #[inline] + fn serialize(&self, serializer: S) -> StdResult { + SerializableTable::new(self, Default::default(), Default::default()).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> SerializableTable<'a> { + #[inline] + pub(crate) fn new( + table: &'a Table, + options: crate::serde::de::Options, + visited: Rc>>, + ) -> Self { + Self { + table, + options, + visited, } } } -#[cfg(feature = "serialize")] -impl<'lua> Serialize for Table<'lua> { +impl TableSequence<'_, V> { + /// Sets the length (hint) of the sequence. + #[cfg(feature = "serde")] + pub(crate) fn with_len(mut self, len: usize) -> Self { + self.len = Some(len); + self + } +} + +#[cfg(feature = "serde")] +impl Serialize for SerializableTable<'_> { fn serialize(&self, serializer: S) -> StdResult where S: Serializer, { - thread_local! { - static VISITED: RefCell> = RefCell::new(FxHashSet::default()); - } + use crate::serde::de::{MapPairs, RecursionGuard, check_value_for_skip}; + use crate::value::SerializableValue; + + let convert_result = |res: Result<()>, serialize_err: Option| match res { + Ok(v) => Ok(v), + Err(Error::SerializeError(_)) if serialize_err.is_some() => Err(serialize_err.unwrap()), + Err(Error::SerializeError(msg)) => Err(serde::ser::Error::custom(msg)), + Err(err) => Err(serde::ser::Error::custom(err.to_string())), + }; - let lua = self.0.lua; - let ptr = unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, self.0.index)) }; - let res = VISITED.with(|visited| { - { - let mut visited = visited.borrow_mut(); - if visited.contains(&ptr) { - return Err(ser::Error::custom("recursive table detected")); + let options = self.options; + let visited = &self.visited; + let _guard = RecursionGuard::new(self.table, visited); + + // Array + if let Some(len) = self.table.encode_as_array(self.options) { + let mut seq = serializer.serialize_seq(Some(len))?; + let mut serialize_err = None; + let res = self.table.for_each_value_by_len::(len, |value| { + let skip = check_value_for_skip(&value, self.options, visited) + .map_err(|err| Error::SerializeError(err.to_string()))?; + if skip { + // continue iteration + return Ok(()); } - visited.insert(ptr); - } + seq.serialize_element(&SerializableValue::new(&value, options, Some(visited))) + .map_err(|err| { + serialize_err = Some(err); + Error::SerializeError(String::new()) + }) + }); + convert_result(res, serialize_err)?; + return seq.end(); + } - let len = self.raw_len() as usize; - if len > 0 || self.is_array() { - let mut seq = serializer.serialize_seq(Some(len))?; - for v in self.clone().raw_sequence_values_by_len::(None) { - let v = v.map_err(serde::ser::Error::custom)?; - seq.serialize_element(&v)?; - } - return seq.end(); + // HashMap + let mut map = serializer.serialize_map(None)?; + let mut serialize_err = None; + let mut process_pair = |key, value| { + let skip_key = check_value_for_skip(&key, self.options, visited) + .map_err(|err| Error::SerializeError(err.to_string()))?; + let skip_value = check_value_for_skip(&value, self.options, visited) + .map_err(|err| Error::SerializeError(err.to_string()))?; + if skip_key || skip_value { + // continue iteration + return Ok(()); } + map.serialize_entry( + &SerializableValue::new(&key, options, Some(visited)), + &SerializableValue::new(&value, options, Some(visited)), + ) + .map_err(|err| { + serialize_err = Some(err); + Error::SerializeError(String::new()) + }) + }; - let mut map = serializer.serialize_map(None)?; - for kv in self.clone().pairs::() { - let (k, v) = kv.map_err(serde::ser::Error::custom)?; - map.serialize_entry(&k, &v)?; - } - map.end() - }); - VISITED.with(|visited| { - visited.borrow_mut().remove(&ptr); - }); - res + let res = if !self.options.sort_keys { + // Fast track + self.table.for_each(process_pair) + } else { + MapPairs::new(self.table, self.options.sort_keys) + .map_err(serde::ser::Error::custom)? + .try_for_each(|kv| { + let (key, value) = kv?; + process_pair(key, value) + }) + }; + convert_result(res, serialize_err)?; + map.end() } } @@ -733,40 +1297,41 @@ impl<'lua> Serialize for Table<'lua> { /// This struct is created by the [`Table::pairs`] method. /// /// [`Table::pairs`]: crate::Table::pairs -pub struct TablePairs<'lua, K, V> { - table: LuaRef<'lua>, - key: Option>, +pub struct TablePairs<'a, K, V> { + guard: LuaGuard, + table: &'a Table, + key: Option, _phantom: PhantomData<(K, V)>, } -impl<'lua, K, V> Iterator for TablePairs<'lua, K, V> +impl Iterator for TablePairs<'_, K, V> where - K: FromLua<'lua>, - V: FromLua<'lua>, + K: FromLua, + V: FromLua, { type Item = Result<(K, V)>; fn next(&mut self) -> Option { if let Some(prev_key) = self.key.take() { - let lua = self.table.lua; + let lua: &RawLua = &self.guard; + let state = lua.state(); let res = (|| unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; - - lua.push_ref(&self.table); - lua.push_value(prev_key)?; - - let next = protect_lua!(lua.state, 2, ffi::LUA_MULTRET, |state| { - ffi::lua_next(state, -2) - })?; - if next != 0 { - let value = lua.pop_value(); - let key = lua.pop_value(); + let _sg = StackGuard::new(state); + check_stack(state, 5)?; + + lua.push_ref(&self.table.0); + lua.push_value(&prev_key)?; + + // It must be safe to call `lua_next` unprotected as deleting a key from a table is + // a permitted operation. + // It fails only if the key is not found (never existed) which seems impossible scenario. + if ffi::lua_next(state, -2) != 0 { + let key = lua.stack_value(-2, None); Ok(Some(( key.clone(), - K::from_lua(key, lua)?, - V::from_lua(value, lua)?, + K::from_lua(key, lua.lua())?, + V::from_stack(-1, lua)?, ))) } else { Ok(None) @@ -792,50 +1357,44 @@ where /// This struct is created by the [`Table::sequence_values`] method. /// /// [`Table::sequence_values`]: crate::Table::sequence_values -pub struct TableSequence<'lua, V> { - table: LuaRef<'lua>, - index: Option, - len: Option, - raw: bool, +pub struct TableSequence<'a, V> { + guard: LuaGuard, + table: &'a Table, + index: Integer, + len: Option, _phantom: PhantomData, } -impl<'lua, V> Iterator for TableSequence<'lua, V> -where - V: FromLua<'lua>, -{ +impl Iterator for TableSequence<'_, V> { type Item = Result; fn next(&mut self) -> Option { - if let Some(index) = self.index.take() { - let lua = self.table.lua; - - let res = (|| unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 1 + if self.raw { 0 } else { 3 })?; - - lua.push_ref(&self.table); - let res = if self.raw { - ffi::lua_rawgeti(lua.state, -1, index) - } else { - protect_lua!(lua.state, 1, 1, |state| ffi::lua_geti(state, -1, index))? - }; - match res { - ffi::LUA_TNIL if index > self.len.unwrap_or(0) => Ok(None), - _ => Ok(Some((index, lua.pop_value()))), - } - })(); + let lua: &RawLua = &self.guard; + let state = lua.state(); + unsafe { + let _sg = StackGuard::new(state); + if let Err(err) = check_stack(state, 1) { + return Some(Err(err)); + } - match res { - Ok(Some((index, r))) => { - self.index = Some(index + 1); - Some(V::from_lua(r, lua)) + lua.push_ref(&self.table.0); + match ffi::lua_rawgeti(state, -1, self.index) { + ffi::LUA_TNIL if self.index as usize > self.len.unwrap_or(0) => None, + _ => { + self.index += 1; + Some(V::from_stack(-1, lua)) } - Ok(None) => None, - Err(err) => Some(Err(err)), } - } else { - None } } } + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(Table: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(Table: Send, Sync); +} diff --git a/src/thread.rs b/src/thread.rs index dcdb565f..65d209ed 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -1,83 +1,152 @@ -use std::cmp; -use std::os::raw::c_int; +//! Lua thread (coroutine) handling. +//! +//! This module provides types for creating and working with Lua coroutines from Rust. +//! Coroutines allow cooperative multitasking within a single Lua state by suspending and +//! resuming execution at well-defined yield points. +//! +//! # Basic Usage +//! +//! Threads are created via [`Lua::create_thread`] and driven by calling [`Thread::resume`]: +//! +//! ```rust +//! # use mlua::{Lua, Result, Thread}; +//! # fn main() -> Result<()> { +//! let lua = Lua::new(); +//! let thread: Thread = lua.load(r#" +//! coroutine.create(function(a, b) +//! coroutine.yield(a + b) +//! return a * b +//! end) +//! "#).eval()?; +//! +//! assert_eq!(thread.resume::((3, 4))?, 7); +//! assert_eq!(thread.resume::(())?, 12); +//! # Ok(()) +//! # } +//! ``` +//! +//! # Async Support +//! +//! When the `async` feature is enabled, a [`Thread`] can be converted into an [`AsyncThread`] +//! via [`Thread::into_async`], which implements both [`Future`] and [`Stream`]. +//! This integrates Lua coroutines naturally with Rust async runtimes such as Tokio. +//! +//! [`Lua::create_thread`]: crate::Lua::create_thread +//! [`Future`]: std::future::Future +//! [`Stream`]: futures_util::stream::Stream + +use std::fmt; +use std::os::raw::{c_int, c_void}; use crate::error::{Error, Result}; -use crate::ffi; -use crate::types::LuaRef; -use crate::util::{check_stack, error_traceback, pop_error, StackGuard}; -use crate::value::{FromLuaMulti, ToLuaMulti}; - -#[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", -))] use crate::function::Function; +use crate::state::RawLua; +use crate::traits::{FromLuaMulti, IntoLuaMulti}; +use crate::types::{LuaType, ValueRef}; +use crate::util::{StackGuard, check_stack, error_traceback_thread, pop_error}; + +#[cfg(not(feature = "luau"))] +use crate::{ + debug::{Debug, HookTriggers}, + types::HookKind, +}; #[cfg(feature = "async")] use { - crate::{ - lua::{Lua, ASYNC_POLL_PENDING}, - value::{MultiValue, Value}, - }, - futures_core::{future::Future, stream::Stream}, + futures_util::stream::Stream, std::{ - cell::RefCell, + future::Future, marker::PhantomData, pin::Pin, + ptr::NonNull, task::{Context, Poll, Waker}, }, }; -/// Status of a Lua thread (or coroutine). +/// Status of a Lua thread (coroutine). #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ThreadStatus { - /// The thread was just created, or is suspended because it has called `coroutine.yield`. + /// The thread was just created or is suspended (yielded). /// /// If a thread is in this state, it can be resumed by calling [`Thread::resume`]. - /// - /// [`Thread::resume`]: crate::Thread::resume Resumable, - /// Either the thread has finished executing, or the thread is currently running. - Unresumable, + /// The thread is currently running. + Running, + /// The thread has finished executing. + Finished, /// The thread has raised a Lua error during execution. Error, } -/// Handle to an internal Lua thread (or coroutine). -#[derive(Clone, Debug)] -pub struct Thread<'lua>(pub(crate) LuaRef<'lua>); +/// Internal representation of a Lua thread status. +/// +/// The number in `New` and `Yielded` variants is the number of arguments pushed +/// to the thread stack. +#[derive(Clone, Copy)] +enum ThreadStatusInner { + New(c_int), + Running, + Yielded(c_int), + Finished, + Error, +} + +impl ThreadStatusInner { + #[cfg(feature = "async")] + #[inline(always)] + fn is_resumable(self) -> bool { + matches!(self, ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_)) + } + + #[cfg(feature = "async")] + #[inline(always)] + fn is_yielded(self) -> bool { + matches!(self, ThreadStatusInner::Yielded(_)) + } +} + +/// Handle to an internal Lua thread (coroutine). +#[derive(Clone, PartialEq)] +pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State); + +#[cfg(feature = "send")] +unsafe impl Send for Thread {} +#[cfg(feature = "send")] +unsafe impl Sync for Thread {} /// Thread (coroutine) representation as an async [`Future`] or [`Stream`]. /// -/// Requires `feature = "async"` -/// -/// [`Future`]: futures_core::future::Future -/// [`Stream`]: futures_core::stream::Stream +/// [`Future`]: std::future::Future +/// [`Stream`]: futures_util::stream::Stream #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] -#[derive(Debug)] -pub struct AsyncThread<'lua, R> { - thread: Thread<'lua>, - args0: RefCell>>>, - ret: PhantomData, +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct AsyncThread { + thread: Thread, + ret: PhantomData R>, recycle: bool, } -impl<'lua> Thread<'lua> { +impl Thread { + /// Returns reference to the Lua state that this thread is associated with. + #[inline(always)] + pub fn state(&self) -> *mut ffi::lua_State { + self.1 + } + /// Resumes execution of this thread. /// - /// Equivalent to `coroutine.resume`. + /// Equivalent to [`coroutine.resume`]. /// - /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it - /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments - /// are passed to its main function. + /// Passes `args` as arguments to the thread. If the coroutine has called [`coroutine.yield`], + /// it will return these arguments. Otherwise, the coroutine wasn't yet started, so the + /// arguments are passed to its main function. /// - /// If the thread is no longer in `Active` state (meaning it has finished execution or - /// encountered an error), this will return `Err(CoroutineInactive)`, otherwise will return `Ok` - /// as follows: + /// If the thread is no longer resumable (meaning it has finished execution or encountered an + /// error), this will return [`Error::CoroutineUnresumable`], otherwise will return `Ok` as + /// follows: /// - /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread + /// If the thread calls [`coroutine.yield`], returns the values passed to `yield`. If the thread /// `return`s values from its main function, returns those. /// /// # Examples @@ -95,128 +164,223 @@ impl<'lua> Thread<'lua> { /// end) /// "#).eval()?; /// - /// assert_eq!(thread.resume::<_, u32>(42)?, 123); - /// assert_eq!(thread.resume::<_, u32>(43)?, 987); + /// assert_eq!(thread.resume::(42)?, 123); + /// assert_eq!(thread.resume::(43)?, 987); /// /// // The coroutine has now returned, so `resume` will fail - /// match thread.resume::<_, u32>(()) { - /// Err(Error::CoroutineInactive) => {}, + /// match thread.resume::(()) { + /// Err(Error::CoroutineUnresumable) => {}, /// unexpected => panic!("unexpected result {:?}", unexpected), /// } /// # Ok(()) /// # } /// ``` - pub fn resume(&self, args: A) -> Result + /// + /// [`coroutine.resume`]: https://www.lua.org/manual/5.4/manual.html#pdf-coroutine.resume + /// [`coroutine.yield`]: https://www.lua.org/manual/5.4/manual.html#pdf-coroutine.yield + pub fn resume(&self, args: impl IntoLuaMulti) -> Result where - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, + R: FromLuaMulti, { - let lua = self.0.lua; - let mut args = args.to_lua_multi(lua)?; - let nargs = args.len() as c_int; - let results = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, cmp::max(nargs + 1, 3))?; - - let thread_state = - lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index)); - - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 { - return Err(Error::CoroutineInactive); - } + let lua = self.0.lua.lock(); + let mut pushed_nargs = match self.status_inner(&lua) { + ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs, + _ => return Err(Error::CoroutineUnresumable), + }; + + let state = lua.state(); + let thread_state = self.state(); + unsafe { + let _sg = StackGuard::new(state); - check_stack(thread_state, nargs)?; - for arg in args.drain_all() { - lua.push_value(arg)?; + let nargs = args.push_into_stack_multi(&lua)?; + if nargs > 0 { + check_stack(thread_state, nargs)?; + ffi::lua_xmove(state, thread_state, nargs); + pushed_nargs += nargs; } - ffi::lua_xmove(lua.state, thread_state, nargs); - let mut nresults = 0; + let _thread_sg = StackGuard::with_top(thread_state, 0); + let (_, nresults) = self.resume_inner(&lua, pushed_nargs)?; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - let ret = ffi::lua_resume(thread_state, lua.state, nargs, &mut nresults as *mut c_int); - if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD { - protect_lua!(lua.state, 0, 0, |_| error_traceback(thread_state))?; - return Err(pop_error(thread_state, ret)); - } + R::from_stack_multi(nresults, &lua) + } + } - let mut results = args; // Reuse MultiValue container - check_stack(lua.state, nresults + 2)?; // 2 is extra for `lua.pop_value()` below - ffi::lua_xmove(thread_state, lua.state, nresults); + /// Resumes execution of this thread, immediately raising an error. + /// + /// This is a Luau specific extension. + #[cfg(feature = "luau")] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn resume_error(&self, error: impl crate::IntoLua) -> Result + where + R: FromLuaMulti, + { + let lua = self.0.lua.lock(); + match self.status_inner(&lua) { + ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => {} + _ => return Err(Error::CoroutineUnresumable), + }; + + let state = lua.state(); + let thread_state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + + check_stack(state, 1)?; + error.push_into_stack(&lua)?; + ffi::lua_xmove(state, thread_state, 1); + + let _thread_sg = StackGuard::with_top(thread_state, 0); + let (_, nresults) = self.resume_inner(&lua, ffi::LUA_RESUMEERROR)?; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); + + R::from_stack_multi(nresults, &lua) + } + } - for _ in 0..nresults { - results.push_front(lua.pop_value()); + /// Resumes execution of this thread. + /// + /// It's similar to `resume()` but leaves `nresults` values on the thread stack. + unsafe fn resume_inner(&self, lua: &RawLua, nargs: c_int) -> Result<(ThreadStatusInner, c_int)> { + let state = lua.state(); + let thread_state = self.state(); + let mut nresults = 0; + #[cfg(not(feature = "luau"))] + let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int); + #[cfg(feature = "luau")] + let ret = ffi::lua_resumex(thread_state, state, nargs, &mut nresults as *mut c_int); + match ret { + ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), + ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), + ffi::LUA_ERRMEM => { + // Don't call error handler for memory errors + Err(pop_error(thread_state, ret)) } - results - }; - R::from_lua_multi(results, lua) + _ => { + check_stack(state, 3)?; + protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?; + Err(pop_error(state, ret)) + } + } } /// Gets the status of the thread. pub fn status(&self) -> ThreadStatus { - let lua = self.0.lua; + match self.status_inner(&self.0.lua.lock()) { + ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => ThreadStatus::Resumable, + ThreadStatusInner::Running => ThreadStatus::Running, + ThreadStatusInner::Finished => ThreadStatus::Finished, + ThreadStatusInner::Error => ThreadStatus::Error, + } + } + + /// Gets the status of the thread (internal implementation). + fn status_inner(&self, lua: &RawLua) -> ThreadStatusInner { + let thread_state = self.state(); + if thread_state == lua.state() { + // The thread is currently running + return ThreadStatusInner::Running; + } + let status = unsafe { ffi::lua_status(thread_state) }; + let top = unsafe { ffi::lua_gettop(thread_state) }; + match status { + ffi::LUA_YIELD => ThreadStatusInner::Yielded(top), + ffi::LUA_OK if top > 0 => ThreadStatusInner::New(top - 1), + ffi::LUA_OK => ThreadStatusInner::Finished, + _ => ThreadStatusInner::Error, + } + } + + /// Returns `true` if this thread is resumable (meaning it can be resumed by calling + /// [`Thread::resume`]). + #[inline(always)] + pub fn is_resumable(&self) -> bool { + self.status() == ThreadStatus::Resumable + } + + /// Returns `true` if this thread is currently running. + #[inline(always)] + pub fn is_running(&self) -> bool { + self.status() == ThreadStatus::Running + } + + /// Returns `true` if this thread has finished executing. + #[inline(always)] + pub fn is_finished(&self) -> bool { + self.status() == ThreadStatus::Finished + } + + /// Returns `true` if this thread has raised a Lua error during execution. + #[inline(always)] + pub fn is_error(&self) -> bool { + self.status() == ThreadStatus::Error + } + + /// Sets a hook function that will periodically be called as Lua code executes. + /// + /// This function is similar or [`Lua::set_hook`] except that it sets for the thread. + /// You can have multiple hooks for different threads. + /// + /// To remove a hook call [`Thread::remove_hook`]. + /// + /// [`Lua::set_hook`]: crate::Lua::set_hook + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn set_hook(&self, triggers: HookTriggers, callback: F) -> Result<()> + where + F: Fn(&crate::Lua, &Debug) -> Result + crate::MaybeSend + 'static, + { + let lua = self.0.lua.lock(); unsafe { - let thread_state = - lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index)); - - let status = ffi::lua_status(thread_state); - if status != ffi::LUA_OK && status != ffi::LUA_YIELD { - ThreadStatus::Error - } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 { - ThreadStatus::Resumable - } else { - ThreadStatus::Unresumable - } + lua.set_thread_hook( + self.state(), + HookKind::Thread(triggers, crate::types::XRc::new(callback)), + ) + } + } + + /// Removes any hook function from this thread. + #[cfg(not(feature = "luau"))] + #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] + pub fn remove_hook(&self) { + let _lua = self.0.lua.lock(); + unsafe { + ffi::lua_sethook(self.state(), None, 0, 0); } } /// Resets a thread /// /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables. - /// Returns a error in case of either the original error that stopped the thread or errors + /// Returns an error in case of either the original error that stopped the thread or errors /// in closing methods. /// - /// In [LuaJIT] and Luau: resets to the initial state of a newly created Lua thread. + /// In Luau: resets to the initial state of a newly created Lua thread. /// Lua threads in arbitrary states (like yielded or errored) can be reset properly. /// + /// Other Lua versions can reset only new or finished threads. + /// /// Sets a Lua function for the thread afterwards. /// - /// Requires `feature = "lua54"` OR `feature = "luajit,vendored"` OR `feature = "luau"` - /// - /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread - /// [LuaJIT]: https://github.com/openresty/luajit2#lua_resetthread - #[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", - ))] - pub fn reset(&self, func: Function<'lua>) -> Result<()> { - let lua = self.0.lua; + /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_closethread + pub fn reset(&self, func: Function) -> Result<()> { + let lua = self.0.lua.lock(); + let thread_state = self.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; + let status = self.status_inner(&lua); + self.reset_inner(status)?; - lua.push_ref(&self.0); - let thread_state = ffi::lua_tothread(lua.state, -1); - - #[cfg(feature = "lua54")] - let status = ffi::lua_resetthread(thread_state); - #[cfg(feature = "lua54")] - if status != ffi::LUA_OK { - return Err(pop_error(thread_state, status)); - } - #[cfg(all(feature = "luajit", feature = "vendored"))] - ffi::lua_resetthread(lua.state, thread_state); - #[cfg(feature = "luau")] - ffi::lua_resetthread(thread_state); - - lua.push_ref(&func.0); - ffi::lua_xmove(lua.state, thread_state, 1); + // Push function to the top of the thread stack + ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); #[cfg(feature = "luau")] { - // Inherit `LUA_GLOBALSINDEX` from the caller - ffi::lua_xpush(lua.state, thread_state, ffi::LUA_GLOBALSINDEX); + // Inherit `LUA_GLOBALSINDEX` from the main thread + ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX); ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX); } @@ -224,27 +388,64 @@ impl<'lua> Thread<'lua> { } } - /// Converts Thread to an AsyncThread which implements [`Future`] and [`Stream`] traits. + unsafe fn reset_inner(&self, status: ThreadStatusInner) -> Result<()> { + match status { + ThreadStatusInner::New(_) => { + // The thread is new, so we can just set the top to 0 + ffi::lua_settop(self.state(), 0); + Ok(()) + } + ThreadStatusInner::Running => Err(Error::runtime("cannot reset a running thread")), + ThreadStatusInner::Finished => Ok(()), + #[cfg(not(any(feature = "lua55", feature = "lua54", feature = "luau")))] + ThreadStatusInner::Yielded(_) | ThreadStatusInner::Error => { + Err(Error::runtime("cannot reset non-finished thread")) + } + #[cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))] + ThreadStatusInner::Yielded(_) | ThreadStatusInner::Error => { + let thread_state = self.state(); + + #[cfg(all(feature = "lua54", not(feature = "vendored")))] + let status = ffi::lua_resetthread(thread_state); + #[cfg(any(feature = "lua55", all(feature = "lua54", feature = "vendored")))] + let status = { + let lua = self.0.lua.lock(); + ffi::lua_closethread(thread_state, lua.state()) + }; + #[cfg(any(feature = "lua55", feature = "lua54"))] + if status != ffi::LUA_OK { + return Err(pop_error(thread_state, status)); + } + #[cfg(feature = "luau")] + ffi::lua_resetthread(thread_state); + + Ok(()) + } + } + } + + /// Converts [`Thread`] to an [`AsyncThread`] which implements [`Future`] and [`Stream`] traits. + /// + /// Only resumable threads can be converted to [`AsyncThread`]. /// - /// `args` are passed as arguments to the thread function for first call. - /// The object calls [`resume()`] while polling and also allows to run rust futures + /// `args` are pushed to the thread stack and will be used when the thread is resumed. + /// The object calls [`resume`] while polling and also allow to run Rust futures /// to completion using an executor. /// - /// Using AsyncThread as a Stream allows to iterate through `coroutine.yield()` - /// values whereas Future version discards that values and poll until the final + /// Using [`AsyncThread`] as a [`Stream`] allow to iterate through [`coroutine.yield`] + /// values whereas [`Future`] version discards that values and poll until the final /// one (returned from the thread function). /// - /// Requires `feature = "async"` - /// - /// [`Future`]: futures_core::future::Future - /// [`Stream`]: futures_core::stream::Stream - /// [`resume()`]: https://www.lua.org/manual/5.4/manual.html#lua_resume + /// [`Future`]: std::future::Future + /// [`Stream`]: futures_util::stream::Stream + /// [`resume`]: https://www.lua.org/manual/5.4/manual.html#lua_resume + /// [`coroutine.yield`]: https://www.lua.org/manual/5.4/manual.html#pdf-coroutine.yield /// /// # Examples /// /// ``` /// # use mlua::{Lua, Result, Thread}; - /// use futures::stream::TryStreamExt; + /// use futures_util::stream::TryStreamExt; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// # let lua = Lua::new(); @@ -258,7 +459,7 @@ impl<'lua> Thread<'lua> { /// end) /// "#).eval()?; /// - /// let mut stream = thread.into_async::<_, i64>(1); + /// let mut stream = thread.into_async::(1)?; /// let mut sum = 0; /// while let Some(n) = stream.try_next().await? { /// sum += n; @@ -271,17 +472,31 @@ impl<'lua> Thread<'lua> { /// ``` #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn into_async(self, args: A) -> AsyncThread<'lua, R> + pub fn into_async(self, args: impl IntoLuaMulti) -> Result> where - A: ToLuaMulti<'lua>, - R: FromLuaMulti<'lua>, + R: FromLuaMulti, { - let args = args.to_lua_multi(self.0.lua); - AsyncThread { - thread: self, - args0: RefCell::new(Some(args)), - ret: PhantomData, - recycle: false, + let lua = self.0.lua.lock(); + if !self.status_inner(&lua).is_resumable() { + return Err(Error::CoroutineUnresumable); + } + + let state = lua.state(); + let thread_state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + + let nargs = args.push_into_stack_multi(&lua)?; + if nargs > 0 { + check_stack(thread_state, nargs)?; + ffi::lua_xmove(state, thread_state, nargs); + } + + Ok(AsyncThread { + thread: self, + ret: PhantomData, + recycle: false, + }) } } @@ -290,179 +505,220 @@ impl<'lua> Thread<'lua> { /// Under the hood replaces the global environment table with a new table, /// that performs writes locally and proxies reads to caller's global environment. /// - /// This mode ideally should be used together with the global sandbox mode [`Lua::sandbox()`]. + /// This mode ideally should be used together with the global sandbox mode [`Lua::sandbox`]. /// /// Please note that Luau links environment table with chunk when loading it into Lua state. /// Therefore you need to load chunks into a thread to link with the thread environment. /// + /// [`Lua::sandbox`]: crate::Lua::sandbox + /// /// # Examples /// /// ``` /// # use mlua::{Lua, Result}; + /// # #[cfg(feature = "luau")] /// # fn main() -> Result<()> { /// let lua = Lua::new(); /// let thread = lua.create_thread(lua.create_function(|lua2, ()| { /// lua2.load("var = 123").exec()?; - /// assert_eq!(lua2.globals().get::<_, u32>("var")?, 123); + /// assert_eq!(lua2.globals().get::("var")?, 123); /// Ok(()) /// })?)?; /// thread.sandbox()?; - /// thread.resume(())?; + /// thread.resume::<()>(())?; /// /// // The global environment should be unchanged - /// assert_eq!(lua.globals().get::<_, Option>("var")?, None); + /// assert_eq!(lua.globals().get::>("var")?, None); /// # Ok(()) /// # } - /// ``` /// - /// Requires `feature = "luau"` - #[cfg(any(feature = "luau", docsrs))] + /// # #[cfg(not(feature = "luau"))] + /// # fn main() { } + /// ``` + #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - #[doc(hidden)] pub fn sandbox(&self) -> Result<()> { - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); + let thread_state = self.state(); unsafe { - let thread = lua.ref_thread_exec(|t| ffi::lua_tothread(t, self.0.index)); - check_stack(thread, 1)?; - check_stack(lua.state, 3)?; - // Inherit `LUA_GLOBALSINDEX` from the caller - ffi::lua_xpush(lua.state, thread, ffi::LUA_GLOBALSINDEX); - ffi::lua_replace(thread, ffi::LUA_GLOBALSINDEX); - protect_lua!(lua.state, 0, 0, |_| ffi::luaL_sandboxthread(thread)) + check_stack(thread_state, 3)?; + check_stack(state, 3)?; + protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state)) } } + + /// Converts this thread to a generic C pointer. + /// + /// There is no way to convert the pointer back to its original value. + /// + /// Typically this function is used only for hashing and debug information. + #[inline] + pub fn to_pointer(&self) -> *const c_void { + self.0.to_pointer() + } } -impl<'lua> PartialEq for Thread<'lua> { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 +impl fmt::Debug for Thread { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_tuple("Thread").field(&self.0).finish() } } +impl LuaType for Thread { + const TYPE_ID: c_int = ffi::LUA_TTHREAD; +} + #[cfg(feature = "async")] -impl<'lua, R> AsyncThread<'lua, R> { - #[inline] +impl AsyncThread { + #[inline(always)] pub(crate) fn set_recyclable(&mut self, recyclable: bool) { self.recycle = recyclable; } } #[cfg(feature = "async")] -#[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", -))] -impl<'lua, R> Drop for AsyncThread<'lua, R> { +impl Drop for AsyncThread { fn drop(&mut self) { + #[allow(clippy::collapsible_if)] if self.recycle { - unsafe { - self.thread.0.lua.recycle_thread(&mut self.thread); + if let Some(lua) = self.thread.0.lua.try_lock() { + unsafe { + let mut status = self.thread.status_inner(&lua); + if matches!(status, ThreadStatusInner::Yielded(0)) { + // The thread is dropped while yielded, resume it with the "terminate" signal + ffi::lua_pushlightuserdata(self.thread.1, crate::Lua::poll_terminate().0); + if let Ok((new_status, _)) = self.thread.resume_inner(&lua, 1) { + // `new_status` should always be `ThreadStatusInner::Yielded(0)` + status = new_status; + } + } + + // For Lua 5.4 this also closes all pending to-be-closed variables + if self.thread.reset_inner(status).is_ok() { + lua.recycle_thread(&mut self.thread); + } + } } } } } #[cfg(feature = "async")] -impl<'lua, R> Stream for AsyncThread<'lua, R> -where - R: FromLuaMulti<'lua>, -{ +impl Stream for AsyncThread { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let lua = self.thread.0.lua; - - match self.thread.status() { - ThreadStatus::Resumable => {} + let lua = self.thread.0.lua.lock(); + let nargs = match self.thread.status_inner(&lua) { + ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs, _ => return Poll::Ready(None), }; - let _wg = WakerGuard::new(lua, cx.waker().clone()); - let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { - self.thread.resume(args?)? - } else { - self.thread.resume(())? - }; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(&lua, cx.waker()); + + let (status, nresults) = (self.thread).resume_inner(&lua, nargs)?; + + if status.is_yielded() { + if nresults == 1 && is_poll_pending(thread_state) { + return Poll::Pending; + } + // Continue polling + cx.waker().wake_by_ref(); + } - if is_poll_pending(&ret) { - return Poll::Pending; - } + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - cx.waker().wake_by_ref(); - Poll::Ready(Some(R::from_lua_multi(ret, lua))) + Poll::Ready(Some(R::from_stack_multi(nresults, &lua))) + } } } #[cfg(feature = "async")] -impl<'lua, R> Future for AsyncThread<'lua, R> -where - R: FromLuaMulti<'lua>, -{ +impl Future for AsyncThread { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let lua = self.thread.0.lua; - - match self.thread.status() { - ThreadStatus::Resumable => {} - _ => return Poll::Ready(Err(Error::CoroutineInactive)), + let lua = self.thread.0.lua.lock(); + let nargs = match self.thread.status_inner(&lua) { + ThreadStatusInner::New(nargs) | ThreadStatusInner::Yielded(nargs) => nargs, + _ => return Poll::Ready(Err(Error::CoroutineUnresumable)), }; - let _wg = WakerGuard::new(lua, cx.waker().clone()); - let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() { - self.thread.resume(args?)? - } else { - self.thread.resume(())? - }; + let state = lua.state(); + let thread_state = self.thread.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::with_top(thread_state, 0); + let _wg = WakerGuard::new(&lua, cx.waker()); + + let (status, nresults) = self.thread.resume_inner(&lua, nargs)?; + + if status.is_yielded() { + if !(nresults == 1 && is_poll_pending(thread_state)) { + // Ignore values returned via yield() + cx.waker().wake_by_ref(); + } + return Poll::Pending; + } - if is_poll_pending(&ret) { - return Poll::Pending; - } + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); - if let ThreadStatus::Resumable = self.thread.status() { - // Ignore value returned via yield() - cx.waker().wake_by_ref(); - return Poll::Pending; + Poll::Ready(R::from_stack_multi(nresults, &lua)) } - - Poll::Ready(R::from_lua_multi(ret, lua)) } } #[cfg(feature = "async")] #[inline(always)] -fn is_poll_pending(val: &MultiValue) -> bool { - match val.iter().enumerate().last() { - Some((0, Value::LightUserData(ud))) => { - std::ptr::eq(ud.0 as *const u8, &ASYNC_POLL_PENDING as *const u8) - } - _ => false, - } +unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool { + ffi::lua_tolightuserdata(state, -1) == crate::Lua::poll_pending().0 } #[cfg(feature = "async")] -struct WakerGuard<'lua> { - lua: &'lua Lua, - prev: Option, +struct WakerGuard<'lua, 'a> { + lua: &'lua RawLua, + prev: NonNull, + _phantom: PhantomData<&'a ()>, } #[cfg(feature = "async")] -impl<'lua> WakerGuard<'lua> { +impl<'lua, 'a> WakerGuard<'lua, 'a> { #[inline] - pub fn new(lua: &Lua, waker: Waker) -> Result { - unsafe { - let prev = lua.set_waker(Some(waker)); - Ok(WakerGuard { lua, prev }) - } + pub fn new(lua: &'lua RawLua, waker: &'a Waker) -> Result> { + let prev = lua.set_waker(NonNull::from(waker)); + Ok(WakerGuard { + lua, + prev, + _phantom: PhantomData, + }) } } #[cfg(feature = "async")] -impl<'lua> Drop for WakerGuard<'lua> { +impl Drop for WakerGuard<'_, '_> { fn drop(&mut self) { - unsafe { - self.lua.set_waker(self.prev.take()); - } + self.lua.set_waker(self.prev); } } + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(Thread: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(Thread: Send, Sync); + #[cfg(all(feature = "async", not(feature = "send")))] + static_assertions::assert_not_impl_any!(AsyncThread<()>: Send); + #[cfg(all(feature = "async", feature = "send"))] + static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync); +} diff --git a/src/traits.rs b/src/traits.rs new file mode 100644 index 00000000..405a95f7 --- /dev/null +++ b/src/traits.rs @@ -0,0 +1,259 @@ +//! Core conversion and extension traits. +//! +//! This module provides the fundamental traits for converting values between Rust and Lua, +//! and for defining native Lua callable functions. + +use std::os::raw::c_int; +use std::sync::Arc; + +use crate::error::{Error, Result}; +use crate::multi::MultiValue; +use crate::private::Sealed; +use crate::state::{Lua, RawLua, WeakLua}; +use crate::util::{check_stack, parse_lookup_path, short_type_name}; +use crate::value::Value; + +#[cfg(feature = "async")] +use crate::function::AsyncCallFuture; + +/// Trait for types convertible to [`Value`]. +pub trait IntoLua: Sized { + /// Performs the conversion. + fn into_lua(self, lua: &Lua) -> Result; + + /// Pushes the value into the Lua stack. + /// + /// # Safety + /// This method does not check Lua stack space. + #[doc(hidden)] + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_value(&self.into_lua(lua.lua())?) + } +} + +/// Trait for types convertible from [`Value`]. +pub trait FromLua: Sized { + /// Performs the conversion. + fn from_lua(value: Value, lua: &Lua) -> Result; + + /// Performs the conversion for an argument (eg. function argument). + /// + /// `i` is the argument index (position), + /// `to` is a function name that received the argument. + #[doc(hidden)] + #[inline] + fn from_lua_arg(arg: Value, i: usize, to: Option<&str>, lua: &Lua) -> Result { + Self::from_lua(arg, lua).map_err(|err| Error::BadArgument { + to: to.map(|s| s.to_string()), + pos: i, + name: None, + cause: Arc::new(err), + }) + } + + /// Performs the conversion for a value in the Lua stack at index `idx`. + #[doc(hidden)] + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Self::from_lua(lua.stack_value(idx, None), lua.lua()) + } + + /// Same as `from_lua_arg` but for a value in the Lua stack at index `idx`. + #[doc(hidden)] + #[inline] + unsafe fn from_stack_arg(idx: c_int, i: usize, to: Option<&str>, lua: &RawLua) -> Result { + Self::from_stack(idx, lua).map_err(|err| Error::BadArgument { + to: to.map(|s| s.to_string()), + pos: i, + name: None, + cause: Arc::new(err), + }) + } +} + +/// Trait for types convertible to any number of Lua values. +/// +/// This is a generalization of [`IntoLua`], allowing any number of resulting Lua values instead of +/// just one. Any type that implements [`IntoLua`] will automatically implement this trait. +pub trait IntoLuaMulti: Sized { + /// Performs the conversion. + fn into_lua_multi(self, lua: &Lua) -> Result; + + /// Pushes the values into the Lua stack. + /// + /// Returns number of pushed values. + #[doc(hidden)] + #[inline] + unsafe fn push_into_stack_multi(self, lua: &RawLua) -> Result { + let values = self.into_lua_multi(lua.lua())?; + let len: c_int = values.len().try_into().unwrap(); + unsafe { + check_stack(lua.state(), len + 1)?; + for val in &values { + lua.push_value(val)?; + } + } + Ok(len) + } +} + +/// Trait for types that can be created from an arbitrary number of Lua values. +/// +/// This is a generalization of [`FromLua`], allowing an arbitrary number of Lua values to +/// participate in the conversion. Any type that implements [`FromLua`] will automatically +/// implement this trait. +pub trait FromLuaMulti: Sized { + /// Performs the conversion. + /// + /// In case `values` contains more values than needed to perform the conversion, the excess + /// values should be ignored. This reflects the semantics of Lua when calling a function or + /// assigning values. Similarly, if not enough values are given, conversions should assume that + /// any missing values are nil. + fn from_lua_multi(values: MultiValue, lua: &Lua) -> Result; + + /// Performs the conversion for a list of arguments. + /// + /// `i` is an index (position) of the first argument, + /// `to` is a function name that received the arguments. + #[doc(hidden)] + #[inline] + fn from_lua_args(args: MultiValue, i: usize, to: Option<&str>, lua: &Lua) -> Result { + let _ = (i, to); + Self::from_lua_multi(args, lua) + } + + /// Performs the conversion for a number of values in the Lua stack. + #[doc(hidden)] + #[inline] + unsafe fn from_stack_multi(nvals: c_int, lua: &RawLua) -> Result { + let mut values = MultiValue::with_capacity(nvals as usize); + for idx in 0..nvals { + values.push_back(lua.stack_value(-nvals + idx, None)); + } + Self::from_lua_multi(values, lua.lua()) + } + + /// Same as `from_lua_args` but for a number of values in the Lua stack. + #[doc(hidden)] + #[inline] + unsafe fn from_stack_args(nargs: c_int, i: usize, to: Option<&str>, lua: &RawLua) -> Result { + let _ = (i, to); + Self::from_stack_multi(nargs, lua) + } +} + +/// A trait for types that can be used as Lua objects (usually table and userdata). +pub trait ObjectLike: Sealed { + /// Gets the value associated to `key` from the object, assuming it has `__index` metamethod. + fn get(&self, key: impl IntoLua) -> Result; + + /// Sets the value associated to `key` in the object, assuming it has `__newindex` metamethod. + fn set(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()>; + + /// Calls the object as a function assuming it has `__call` metamethod. + /// + /// The metamethod is called with the object as its first argument, followed by the passed + /// arguments. + fn call(&self, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti; + + /// Asynchronously calls the object as a function assuming it has `__call` metamethod. + /// + /// The metamethod is called with the object as its first argument, followed by the passed + /// arguments. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn call_async(&self, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti; + + /// Gets the function associated to key `name` from the object and calls it, + /// passing the object itself along with `args` as function arguments. + fn call_method(&self, name: &str, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti; + + /// Gets the function associated to key `name` from the object and asynchronously calls it, + /// passing the object itself along with `args` as function arguments. + /// + /// This might invoke the `__index` metamethod. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn call_async_method(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti; + + /// Gets the function associated to key `name` from the object and calls it, + /// passing `args` as function arguments. + /// + /// This might invoke the `__index` metamethod. + fn call_function(&self, name: &str, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti; + + /// Gets the function associated to key `name` from the object and asynchronously calls it, + /// passing `args` as function arguments. + /// + /// This might invoke the `__index` metamethod. + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn call_async_function(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti; + + /// Look up a value by a path of keys. + /// + /// The syntax is similar to accessing nested tables in Lua, with additional support for + /// `?` operator to perform safe navigation. + /// + /// For example, the path `a[1].c` is equivalent to `table.a[1].c` in Lua. + /// With `?` operator, `a[1]?.c` is equivalent to `table.a[1] and table.a[1].c or nil` in Lua. + /// + /// Bracket notation rules: + /// - `[123]` - integer keys + /// - `["string key"]` or `['string key']` - string keys (must be quoted) + /// - String keys support escape sequences: `\"`, `\'`, `\\` + fn get_path(&self, path: &str) -> Result { + let mut current = self.to_value(); + for (key, safe_nil) in parse_lookup_path(path)? { + current = match current { + Value::Table(table) => table.get::(key), + Value::UserData(ud) => ud.get::(key), + _ => { + let type_name = current.type_name(); + let err = format!("attempt to index a {type_name} value with key '{key}'"); + Err(Error::runtime(err)) + } + }?; + if safe_nil && (current == Value::Nil || current == Value::NULL) { + break; + } + } + + let lua = self.weak_lua().lock(); + V::from_lua(current, lua.lua()) + } + + /// Converts the object to a string in a human-readable format. + /// + /// This might invoke the `__tostring` metamethod. + fn to_string(&self) -> Result; + + /// Converts the object to a Lua value. + fn to_value(&self) -> Value; + + /// Gets a reference to the associated Lua state. + #[doc(hidden)] + fn weak_lua(&self) -> &WeakLua; +} + +pub(crate) trait ShortTypeName { + #[inline(always)] + fn type_name() -> String { + short_type_name::() + } +} + +impl ShortTypeName for T {} diff --git a/src/types.rs b/src/types.rs index 68a7032e..a6a91030 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,22 +1,27 @@ use std::cell::UnsafeCell; -use std::hash::{Hash, Hasher}; use std::os::raw::{c_int, c_void}; -use std::sync::{Arc, Mutex}; -use std::{fmt, mem, ptr}; -#[cfg(feature = "lua54")] -use std::ffi::CStr; +#[cfg(not(feature = "luau"))] +use crate::debug::{Debug, HookTriggers}; +use crate::error::Result; +use crate::state::{ExtraData, Lua, RawLua}; -#[cfg(feature = "async")] -use futures_core::future::LocalBoxFuture; +// Re-export mutex wrappers +pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak}; -use crate::error::Result; -use crate::ffi; -#[cfg(not(feature = "luau"))] -use crate::hook::Debug; -use crate::lua::{ExtraData, Lua}; -use crate::util::{assert_stack, StackGuard}; -use crate::value::MultiValue; +#[cfg(all(feature = "async", feature = "send"))] +pub(crate) type BoxFuture<'a, T> = futures_util::future::BoxFuture<'a, T>; + +#[cfg(all(feature = "async", not(feature = "send")))] +pub(crate) type BoxFuture<'a, T> = futures_util::future::LocalBoxFuture<'a, T>; + +pub use app_data::{AppData, AppDataRef, AppDataRefMut}; +pub use either::Either; +pub use registry_key::RegistryKey; +pub(crate) use value_ref::ValueRef; + +#[cfg(feature = "async")] +pub(crate) use value_ref::ValueRefIndex; /// Type of Lua integer numbers. pub type Integer = ffi::lua_Integer; @@ -27,161 +32,143 @@ pub type Number = ffi::lua_Number; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct LightUserData(pub *mut c_void); -pub(crate) type Callback<'lua, 'a> = - Box) -> Result> + 'a>; +#[cfg(feature = "send")] +unsafe impl Send for LightUserData {} +#[cfg(feature = "send")] +unsafe impl Sync for LightUserData {} + +#[cfg(feature = "send")] +type CallbackFn<'a> = dyn Fn(&RawLua, c_int) -> Result + Send + 'a; + +#[cfg(not(feature = "send"))] +type CallbackFn<'a> = dyn Fn(&RawLua, c_int) -> Result + 'a; + +pub(crate) type Callback = Box>; +pub(crate) type CallbackPtr = *const CallbackFn<'static>; + +pub(crate) type ScopedCallback<'s> = Box Result + 's>; pub(crate) struct Upvalue { pub(crate) data: T, - pub(crate) extra: Arc>, + pub(crate) extra: XRc>, } -pub(crate) type CallbackUpvalue = Upvalue>; +pub(crate) type CallbackUpvalue = Upvalue>; -#[cfg(feature = "async")] -pub(crate) type AsyncCallback<'lua, 'a> = - Box) -> LocalBoxFuture<'lua, Result>> + 'a>; +#[cfg(all(feature = "async", feature = "send"))] +pub(crate) type AsyncCallback = + Box Fn(&'a RawLua, c_int) -> BoxFuture<'a, Result> + Send + 'static>; + +#[cfg(all(feature = "async", not(feature = "send")))] +pub(crate) type AsyncCallback = + Box Fn(&'a RawLua, c_int) -> BoxFuture<'a, Result> + 'static>; #[cfg(feature = "async")] -pub(crate) type AsyncCallbackUpvalue = Upvalue>; +pub(crate) type AsyncCallbackUpvalue = Upvalue; #[cfg(feature = "async")] -pub(crate) type AsyncPollUpvalue = Upvalue>>>; +pub(crate) type AsyncPollUpvalue = Upvalue>>>; -/// Type to set next Luau VM action after executing interrupt function. -#[cfg(any(feature = "luau", doc))] -#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +/// Type to set next Lua VM action after executing interrupt or hook function. pub enum VmState { Continue, + /// Yield the current thread. + /// + /// Supported by Lua 5.3+ and Luau. Yield, } +#[cfg(not(feature = "luau"))] +pub(crate) enum HookKind { + Global, + Thread(HookTriggers, HookCallback), +} + #[cfg(all(feature = "send", not(feature = "luau")))] -pub(crate) type HookCallback = Arc Result<()> + Send>; +pub(crate) type HookCallback = XRc Result + Send>; #[cfg(all(not(feature = "send"), not(feature = "luau")))] -pub(crate) type HookCallback = Arc Result<()>>; +pub(crate) type HookCallback = XRc Result>; + +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type InterruptCallback = XRc Result + Send>; + +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type InterruptCallback = XRc Result>; + +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type ThreadCreationCallback = XRc Result<()> + Send>; -#[cfg(all(feature = "luau", feature = "send"))] -pub(crate) type InterruptCallback = Arc Result + Send>; +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type ThreadCreationCallback = XRc Result<()>>; -#[cfg(all(feature = "luau", not(feature = "send")))] -pub(crate) type InterruptCallback = Arc Result>; +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type ThreadCollectionCallback = XRc; -#[cfg(all(feature = "send", feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()> + Send>; +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type ThreadCollectionCallback = XRc; -#[cfg(all(not(feature = "send"), feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()>>; +#[cfg(feature = "send")] +#[cfg(any(feature = "lua55", feature = "lua54"))] +pub(crate) type WarnCallback = XRc Result<()> + Send>; + +#[cfg(not(feature = "send"))] +#[cfg(any(feature = "lua55", feature = "lua54"))] +pub(crate) type WarnCallback = XRc Result<()>>; +/// A trait that adds `Send` requirement if `send` feature is enabled. #[cfg(feature = "send")] pub trait MaybeSend: Send {} #[cfg(feature = "send")] impl MaybeSend for T {} +/// A trait that adds `Send` requirement if `send` feature is enabled. #[cfg(not(feature = "send"))] pub trait MaybeSend {} #[cfg(not(feature = "send"))] impl MaybeSend for T {} -pub(crate) struct DestructedUserdataMT; - -/// An auto generated key into the Lua registry. -/// -/// This is a handle to a value stored inside the Lua registry. It is not automatically -/// garbage collected on Drop, but it can be removed with [`Lua::remove_registry_value`], -/// and instances not manually removed can be garbage collected with [`Lua::expire_registry_values`]. -/// -/// Be warned, If you place this into Lua via a [`UserData`] type or a rust callback, it is *very -/// easy* to accidentally cause reference cycles that the Lua garbage collector cannot resolve. -/// Instead of placing a [`RegistryKey`] into a [`UserData`] type, prefer instead to use -/// [`AnyUserData::set_user_value`] / [`AnyUserData::get_user_value`]. -/// -/// [`UserData`]: crate::UserData -/// [`RegistryKey`]: crate::RegistryKey -/// [`Lua::remove_registry_value`]: crate::Lua::remove_registry_value -/// [`Lua::expire_registry_values`]: crate::Lua::expire_registry_values -/// [`AnyUserData::set_user_value`]: crate::AnyUserData::set_user_value -/// [`AnyUserData::get_user_value`]: crate::AnyUserData::get_user_value -pub struct RegistryKey { - pub(crate) registry_id: c_int, - pub(crate) unref_list: Arc>>>, -} - -impl fmt::Debug for RegistryKey { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "RegistryKey({})", self.registry_id) - } -} - -impl Hash for RegistryKey { - fn hash(&self, state: &mut H) { - self.registry_id.hash(state) - } -} +/// A trait that adds `Sync` requirement if `send` feature is enabled. +#[cfg(feature = "send")] +pub trait MaybeSync: Sync {} +#[cfg(feature = "send")] +impl MaybeSync for T {} -impl PartialEq for RegistryKey { - fn eq(&self, other: &RegistryKey) -> bool { - self.registry_id == other.registry_id && Arc::ptr_eq(&self.unref_list, &other.unref_list) - } -} +/// A trait that adds `Sync` requirement if `send` feature is enabled. +#[cfg(not(feature = "send"))] +pub trait MaybeSync {} +#[cfg(not(feature = "send"))] +impl MaybeSync for T {} -impl Eq for RegistryKey {} +pub(crate) struct DestructedUserdata; -impl Drop for RegistryKey { - fn drop(&mut self) { - let mut unref_list = mlua_expect!(self.unref_list.lock(), "unref list poisoned"); - if let Some(list) = unref_list.as_mut() { - list.push(self.registry_id); - } - } +pub(crate) trait LuaType { + const TYPE_ID: c_int; } -impl RegistryKey { - // Destroys the RegistryKey without adding to the drop list - pub(crate) fn take(self) -> c_int { - let registry_id = self.registry_id; - unsafe { - ptr::read(&self.unref_list); - mem::forget(self); - } - registry_id - } +impl LuaType for bool { + const TYPE_ID: c_int = ffi::LUA_TBOOLEAN; } -pub(crate) struct LuaRef<'lua> { - pub(crate) lua: &'lua Lua, - pub(crate) index: c_int, +impl LuaType for Number { + const TYPE_ID: c_int = ffi::LUA_TNUMBER; } -impl<'lua> fmt::Debug for LuaRef<'lua> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Ref({})", self.index) - } +impl LuaType for LightUserData { + const TYPE_ID: c_int = ffi::LUA_TLIGHTUSERDATA; } -impl<'lua> Clone for LuaRef<'lua> { - fn clone(&self) -> Self { - self.lua.clone_ref(self) - } -} +mod app_data; +mod registry_key; +mod sync; +mod value_ref; -impl<'lua> Drop for LuaRef<'lua> { - fn drop(&mut self) { - if self.index > 0 { - self.lua.drop_ref(self); - } - } -} +#[cfg(test)] +mod assertions { + use super::*; -impl<'lua> PartialEq for LuaRef<'lua> { - fn eq(&self, other: &Self) -> bool { - let lua = self.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - assert_stack(lua.state, 2); - lua.push_ref(self); - lua.push_ref(other); - ffi::lua_rawequal(lua.state, -1, -2) == 1 - } - } + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(ValueRef: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(ValueRef: Send, Sync); } diff --git a/src/types/app_data.rs b/src/types/app_data.rs new file mode 100644 index 00000000..0ec7a9f2 --- /dev/null +++ b/src/types/app_data.rs @@ -0,0 +1,212 @@ +use std::any::{Any, TypeId}; +use std::cell::{BorrowError, BorrowMutError, Cell, Ref, RefCell, RefMut, UnsafeCell}; +use std::fmt; +use std::ops::{Deref, DerefMut}; +use std::result::Result as StdResult; + +use rustc_hash::FxHashMap; + +use super::MaybeSend; +use crate::state::LuaGuard; + +#[cfg(not(feature = "send"))] +type Container = UnsafeCell>>>; + +#[cfg(feature = "send")] +type Container = UnsafeCell>>>; + +/// A container for arbitrary data associated with the Lua state. +#[derive(Debug, Default)] +pub struct AppData { + container: Container, + borrow: Cell, +} + +impl AppData { + #[track_caller] + pub(crate) fn insert(&self, data: T) -> Option { + match self.try_insert(data) { + Ok(data) => data, + Err(_) => panic!("cannot mutably borrow app data container"), + } + } + + pub(crate) fn try_insert(&self, data: T) -> StdResult, T> { + if self.borrow.get() != 0 { + return Err(data); + } + // SAFETY: we checked that there are no other references to the container + Ok(unsafe { &mut *self.container.get() } + .insert(TypeId::of::(), RefCell::new(Box::new(data))) + .and_then(|data| data.into_inner().downcast::().ok().map(|data| *data))) + } + + #[inline] + #[track_caller] + pub(crate) fn borrow(&self, guard: Option) -> Option> { + match self.try_borrow(guard) { + Ok(data) => data, + Err(err) => panic!("already mutably borrowed: {err:?}"), + } + } + + pub(crate) fn try_borrow( + &self, + guard: Option, + ) -> Result>, BorrowError> { + let data = unsafe { &*self.container.get() } + .get(&TypeId::of::()) + .map(|c| c.try_borrow()) + .transpose()? + .and_then(|data| Ref::filter_map(data, |data| data.downcast_ref()).ok()); + match data { + Some(data) => { + self.borrow.set(self.borrow.get() + 1); + Ok(Some(AppDataRef { + data, + borrow: &self.borrow, + _guard: guard, + })) + } + None => Ok(None), + } + } + + #[inline] + #[track_caller] + pub(crate) fn borrow_mut(&self, guard: Option) -> Option> { + match self.try_borrow_mut(guard) { + Ok(data) => data, + Err(err) => panic!("already borrowed: {err:?}"), + } + } + + pub(crate) fn try_borrow_mut( + &self, + guard: Option, + ) -> Result>, BorrowMutError> { + let data = unsafe { &*self.container.get() } + .get(&TypeId::of::()) + .map(|c| c.try_borrow_mut()) + .transpose()? + .and_then(|data| RefMut::filter_map(data, |data| data.downcast_mut()).ok()); + match data { + Some(data) => { + self.borrow.set(self.borrow.get() + 1); + Ok(Some(AppDataRefMut { + data, + borrow: &self.borrow, + _guard: guard, + })) + } + None => Ok(None), + } + } + + #[track_caller] + pub(crate) fn remove(&self) -> Option { + if self.borrow.get() != 0 { + panic!("cannot mutably borrow app data container"); + } + // SAFETY: we checked that there are no other references to the container + unsafe { &mut *self.container.get() } + .remove(&TypeId::of::())? + .into_inner() + .downcast::() + .ok() + .map(|data| *data) + } +} + +/// A wrapper type for an immutably borrowed value from an app data container. +/// +/// This type is similar to [`Ref`]. +pub struct AppDataRef<'a, T: ?Sized + 'a> { + data: Ref<'a, T>, + borrow: &'a Cell, + _guard: Option, +} + +impl Drop for AppDataRef<'_, T> { + fn drop(&mut self) { + self.borrow.set(self.borrow.get() - 1); + } +} + +impl Deref for AppDataRef<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl fmt::Display for AppDataRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Debug for AppDataRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +/// A wrapper type for a mutably borrowed value from an app data container. +/// +/// This type is similar to [`RefMut`]. +pub struct AppDataRefMut<'a, T: ?Sized + 'a> { + data: RefMut<'a, T>, + borrow: &'a Cell, + _guard: Option, +} + +impl Drop for AppDataRefMut<'_, T> { + fn drop(&mut self) { + self.borrow.set(self.borrow.get() - 1); + } +} + +impl Deref for AppDataRefMut<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl DerefMut for AppDataRefMut<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.data + } +} + +impl fmt::Display for AppDataRefMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Debug for AppDataRefMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(AppData: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(AppData: Send); + + // Must be !Send + static_assertions::assert_not_impl_any!(AppDataRef<()>: Send); + static_assertions::assert_not_impl_any!(AppDataRefMut<()>: Send); +} diff --git a/src/types/registry_key.rs b/src/types/registry_key.rs new file mode 100644 index 00000000..6df0002e --- /dev/null +++ b/src/types/registry_key.rs @@ -0,0 +1,100 @@ +use std::hash::{Hash, Hasher}; +use std::os::raw::c_int; +use std::sync::Arc; +use std::{fmt, mem, ptr}; + +use parking_lot::Mutex; + +/// An auto generated key into the Lua registry. +/// +/// This is a handle to a value stored inside the Lua registry. It is not automatically +/// garbage collected on Drop, but it can be removed with [`Lua::remove_registry_value`], +/// and instances not manually removed can be garbage collected with +/// [`Lua::expire_registry_values`]. +/// +/// Be warned, If you place this into Lua via a [`UserData`] type or a Rust callback, it is *easy* +/// to accidentally cause reference cycles that the Lua garbage collector cannot resolve. Instead of +/// placing a [`RegistryKey`] into a [`UserData`] type, consider to use +/// [`AnyUserData::set_user_value`]. +/// +/// [`UserData`]: crate::UserData +/// [`RegistryKey`]: crate::RegistryKey +/// [`Lua::remove_registry_value`]: crate::Lua::remove_registry_value +/// [`Lua::expire_registry_values`]: crate::Lua::expire_registry_values +/// [`AnyUserData::set_user_value`]: crate::AnyUserData::set_user_value +pub struct RegistryKey { + pub(crate) registry_id: i32, + pub(crate) unref_list: Arc>>>, +} + +impl fmt::Debug for RegistryKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "RegistryKey({})", self.id()) + } +} + +impl Hash for RegistryKey { + fn hash(&self, state: &mut H) { + self.id().hash(state) + } +} + +impl PartialEq for RegistryKey { + fn eq(&self, other: &RegistryKey) -> bool { + self.id() == other.id() && Arc::ptr_eq(&self.unref_list, &other.unref_list) + } +} + +impl Eq for RegistryKey {} + +impl Drop for RegistryKey { + fn drop(&mut self) { + let registry_id = self.id(); + // We don't need to collect nil slot + if registry_id > ffi::LUA_REFNIL { + let mut unref_list = self.unref_list.lock(); + if let Some(list) = unref_list.as_mut() { + list.push(registry_id); + } + } + } +} + +impl RegistryKey { + /// Creates a new instance of `RegistryKey` + pub(crate) const fn new(id: c_int, unref_list: Arc>>>) -> Self { + RegistryKey { + registry_id: id, + unref_list, + } + } + + /// Returns the underlying Lua reference of this `RegistryKey` + #[inline(always)] + pub fn id(&self) -> c_int { + self.registry_id + } + + /// Sets the unique Lua reference key of this `RegistryKey` + #[inline(always)] + pub(crate) fn set_id(&mut self, id: c_int) { + self.registry_id = id; + } + + /// Destroys the `RegistryKey` without adding to the unref list + pub(crate) fn take(self) -> i32 { + let registry_id = self.id(); + unsafe { + ptr::read(&self.unref_list); + mem::forget(self); + } + registry_id + } +} + +#[cfg(test)] +mod assertions { + use super::*; + + static_assertions::assert_impl_all!(RegistryKey: Send, Sync); +} diff --git a/src/types/sync.rs b/src/types/sync.rs new file mode 100644 index 00000000..753755cf --- /dev/null +++ b/src/types/sync.rs @@ -0,0 +1,77 @@ +#[cfg(feature = "send")] +mod inner { + use parking_lot::{RawMutex, RawThreadId}; + use std::sync::{Arc, Weak}; + + pub(crate) type XRc = Arc; + pub(crate) type XWeak = Weak; + + pub(crate) type ReentrantMutex = parking_lot::ReentrantMutex; + + pub(crate) type ReentrantMutexGuard<'a, T> = parking_lot::ReentrantMutexGuard<'a, T>; + + pub(crate) type ArcReentrantMutexGuard = + parking_lot::lock_api::ArcReentrantMutexGuard; +} + +#[cfg(not(feature = "send"))] +mod inner { + use std::ops::Deref; + use std::rc::{Rc, Weak}; + + pub(crate) type XRc = Rc; + pub(crate) type XWeak = Weak; + + pub(crate) struct ReentrantMutex(T); + + impl ReentrantMutex { + #[inline(always)] + pub(crate) fn new(val: T) -> Self { + ReentrantMutex(val) + } + + #[inline(always)] + pub(crate) fn lock(&self) -> ReentrantMutexGuard<'_, T> { + ReentrantMutexGuard(&self.0) + } + + #[inline(always)] + pub(crate) fn lock_arc(self: &XRc) -> ArcReentrantMutexGuard { + ArcReentrantMutexGuard(Rc::clone(self)) + } + + #[inline(always)] + pub(crate) fn into_lock_arc(self: XRc) -> ArcReentrantMutexGuard { + ArcReentrantMutexGuard(self) + } + + #[inline(always)] + pub(crate) fn data_ptr(&self) -> *const T { + &self.0 as *const _ + } + } + + pub(crate) struct ReentrantMutexGuard<'a, T>(&'a T); + + impl Deref for ReentrantMutexGuard<'_, T> { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.0 + } + } + + pub(crate) struct ArcReentrantMutexGuard(XRc>); + + impl Deref for ArcReentrantMutexGuard { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0.0 + } + } +} + +pub(crate) use inner::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak}; diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs new file mode 100644 index 00000000..564c28f0 --- /dev/null +++ b/src/types/value_ref.rs @@ -0,0 +1,76 @@ +use std::fmt; +use std::os::raw::{c_int, c_void}; + +use super::XRc; +use crate::state::{RawLua, WeakLua}; + +/// A reference to a Lua (complex) value stored in the Lua auxiliary thread. +#[derive(Clone)] +pub struct ValueRef { + pub(crate) lua: WeakLua, + // Keep index separate to avoid additional indirection when accessing it. + pub(crate) index: c_int, + // If `index_count` is `None`, the value does not need to be destroyed. + pub(crate) index_count: Option, +} + +/// A reference to a Lua value index in the auxiliary thread. +/// It's cheap to clone and can be used to track the number of references to a value. +#[derive(Clone)] +pub(crate) struct ValueRefIndex(pub(crate) XRc); + +impl From for ValueRefIndex { + #[inline] + fn from(index: c_int) -> Self { + ValueRefIndex(XRc::new(index)) + } +} + +impl ValueRef { + #[inline] + pub(crate) fn new(lua: &RawLua, index: impl Into) -> Self { + let index = index.into(); + ValueRef { + lua: lua.weak().clone(), + index: *index.0, + index_count: Some(index), + } + } + + #[inline] + pub(crate) fn to_pointer(&self) -> *const c_void { + let lua = self.lua.lock(); + unsafe { ffi::lua_topointer(lua.ref_thread(), self.index) } + } +} + +impl fmt::Debug for ValueRef { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Ref({:p})", self.to_pointer()) + } +} + +impl Drop for ValueRef { + fn drop(&mut self) { + if let Some(ValueRefIndex(index)) = self.index_count.take() { + // It's guaranteed that the inner value returns exactly once. + // This means in particular that the value is not dropped. + if XRc::into_inner(index).is_some() + && let Some(lua) = self.lua.try_lock() + { + unsafe { lua.drop_ref(self) } + } + } + } +} + +impl PartialEq for ValueRef { + fn eq(&self, other: &Self) -> bool { + assert!( + self.lua == other.lua, + "Lua instance passed Value created from a different main Lua state" + ); + let lua = self.lua.lock(); + unsafe { ffi::lua_rawequal(lua.ref_thread(), self.index, other.index) == 1 } + } +} diff --git a/src/userdata.rs b/src/userdata.rs index 0ad5f463..7749c9b0 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -1,42 +1,49 @@ +//! Lua userdata handling. +//! +//! This module provides types for creating and working with Lua userdata from Rust. + use std::any::TypeId; -use std::cell::{Ref, RefCell, RefMut}; +use std::ffi::CStr; use std::fmt; -use std::hash::{Hash, Hasher}; -use std::ops::{Deref, DerefMut}; -use std::os::raw::{c_char, c_int}; -use std::string::String as StdString; +use std::hash::Hash; +use std::os::raw::{c_char, c_void}; + +use crate::Either; +use crate::error::{Error, Result}; +use crate::function::Function; +use crate::state::Lua; +use crate::string::LuaString; +use crate::table::{Table, TablePairs}; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::types::{MaybeSend, MaybeSync, ValueRef}; +use crate::util::{StackGuard, check_stack, get_userdata, push_string, short_type_name, take_userdata}; +use crate::value::Value; #[cfg(feature = "async")] use std::future::Future; -#[cfg(feature = "serialize")] +#[cfg(feature = "serde")] use { serde::ser::{self, Serialize, Serializer}, std::result::Result as StdResult, }; -use crate::error::{Error, Result}; -use crate::ffi; -use crate::function::Function; -use crate::lua::Lua; -use crate::table::{Table, TablePairs}; -use crate::types::{Callback, LuaRef, MaybeSend}; -use crate::util::{check_stack, get_userdata, take_userdata, StackGuard}; -use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti}; - -#[cfg(feature = "async")] -use crate::types::AsyncCallback; - -#[cfg(feature = "lua54")] -pub(crate) const USER_VALUE_MAXSLOT: usize = 8; +// Re-export for convenience +pub(crate) use cell::UserDataStorage; +pub use r#ref::{UserDataOwned, UserDataRef, UserDataRefMut}; +pub use registry::UserDataRegistry; +pub(crate) use registry::{RawUserDataRegistry, UserDataProxy}; +pub(crate) use util::{ + TypeIdHints, borrow_userdata_scoped, borrow_userdata_scoped_mut, collect_userdata, + init_userdata_metatable, +}; /// Kinds of metamethods that can be overridden. /// /// Currently, this mechanism does not allow overriding the `__gc` metamethod, since there is /// generally no need to do so: [`UserData`] implementors can instead just implement `Drop`. -/// -/// [`UserData`]: crate::UserData -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] pub enum MetaMethod { /// The `+` operator. Add, @@ -53,30 +60,53 @@ pub enum MetaMethod { /// The unary minus (`-`) operator. Unm, /// The floor division (//) operator. - /// Requires `feature = "lua54/lua53"` - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))) + )] IDiv, /// The bitwise AND (&) operator. - /// Requires `feature = "lua54/lua53"` - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] BAnd, /// The bitwise OR (|) operator. - /// Requires `feature = "lua54/lua53"` - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] BOr, /// The bitwise XOR (binary ~) operator. - /// Requires `feature = "lua54/lua53"` - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] BXor, /// The bitwise NOT (unary ~) operator. - /// Requires `feature = "lua54/lua53"` - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] BNot, /// The bitwise left shift (<<) operator. - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] Shl, /// The bitwise right shift (>>) operator. - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))) + )] Shr, /// The string concatenation operator `..`. Concat, @@ -98,24 +128,36 @@ pub enum MetaMethod { /// /// This is not an operator, but will be called by methods such as `tostring` and `print`. ToString, + /// The `__todebugstring` metamethod for debug purposes. + /// + /// This is an mlua-specific metamethod that can be used to provide debug representation for + /// userdata. + ToDebugString, /// The `__pairs` metamethod. /// /// This is not an operator, but it will be called by the built-in `pairs` function. - /// - /// Requires `feature = "lua54/lua53/lua52"` #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", - feature = "luajit52", + feature = "luajit52" ))] + #[cfg_attr( + docsrs, + doc(cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luajit52" + ))) + )] Pairs, /// The `__ipairs` metamethod. /// /// This is not an operator, but it will be called by the built-in [`ipairs`] function. /// - /// Requires `feature = "lua52"` - /// /// [`ipairs`]: https://www.lua.org/manual/5.2/manual.html#pdf-ipairs #[cfg(any(feature = "lua52", feature = "luajit52", doc))] #[cfg_attr(docsrs, doc(cfg(any(feature = "lua52", feature = "luajit52"))))] @@ -124,8 +166,6 @@ pub enum MetaMethod { /// /// Executed before the iteration begins, and should return an iterator function like `next` /// (or a custom one). - /// - /// Requires `feature = "lua"` #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] Iter, @@ -133,31 +173,30 @@ pub enum MetaMethod { /// /// Executed when a variable, that marked as to-be-closed, goes out of scope. /// - /// More information about to-be-closed variabled can be found in the Lua 5.4 + /// More information about to-be-closed variables can be found in the Lua 5.4 /// [documentation][lua_doc]. /// - /// Requires `feature = "lua54"` - /// /// [lua_doc]: https://www.lua.org/manual/5.4/manual.html#3.3.8 - #[cfg(any(feature = "lua54"))] + #[cfg(any(feature = "lua55", feature = "lua54"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "lua55", feature = "lua54"))))] Close, - /// A custom metamethod. + /// The `__name`/`__type` metafield. /// - /// Must not be in the protected list: `__gc`, `__metatable`, `__mlua*`. - Custom(StdString), + /// This is not a function, but it's value can be used by `tostring` and `typeof` built-in + /// functions. + #[doc(hidden)] + Type, } -impl PartialEq for MetaMethod { - fn eq(&self, other: &Self) -> bool { - self.name() == other.name() +impl PartialEq for &str { + fn eq(&self, other: &MetaMethod) -> bool { + *self == other.name() } } -impl Eq for MetaMethod {} - -impl Hash for MetaMethod { - fn hash(&self, state: &mut H) { - self.name().hash(state); +impl PartialEq for String { + fn eq(&self, other: &MetaMethod) -> bool { + self == other.name() } } @@ -169,7 +208,7 @@ impl fmt::Display for MetaMethod { impl MetaMethod { /// Returns Lua metamethod name, usually prefixed by two underscores. - pub fn name(&self) -> &str { + pub const fn name(self) -> &'static str { match self { MetaMethod::Add => "__add", MetaMethod::Sub => "__sub", @@ -179,19 +218,19 @@ impl MetaMethod { MetaMethod::Pow => "__pow", MetaMethod::Unm => "__unm", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] MetaMethod::IDiv => "__idiv", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::BAnd => "__band", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::BOr => "__bor", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::BXor => "__bxor", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::BNot => "__bnot", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::Shl => "__shl", - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] MetaMethod::Shr => "__shr", MetaMethod::Concat => "__concat", @@ -203,8 +242,10 @@ impl MetaMethod { MetaMethod::NewIndex => "__newindex", MetaMethod::Call => "__call", MetaMethod::ToString => "__tostring", + MetaMethod::ToDebugString => "__todebugstring", #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -216,93 +257,47 @@ impl MetaMethod { #[cfg(feature = "luau")] MetaMethod::Iter => "__iter", - #[cfg(feature = "lua54")] + #[cfg(any(feature = "lua55", feature = "lua54"))] MetaMethod::Close => "__close", - MetaMethod::Custom(ref name) => name, + #[rustfmt::skip] + MetaMethod::Type => if cfg!(feature = "luau") { "__type" } else { "__name" }, } } - pub(crate) fn validate(self) -> Result { + pub(crate) const fn as_cstr(self) -> &'static CStr { match self { - MetaMethod::Custom(name) if name == "__gc" => Err(Error::MetaMethodRestricted(name)), - MetaMethod::Custom(name) if name == "__metatable" => { - Err(Error::MetaMethodRestricted(name)) - } - MetaMethod::Custom(name) if name.starts_with("__mlua") => { - Err(Error::MetaMethodRestricted(name)) - } - _ => Ok(self), + #[rustfmt::skip] + MetaMethod::Type => if cfg!(feature = "luau") { c"__type" } else { c"__name" }, + _ => unreachable!(), } } -} - -impl From for MetaMethod { - fn from(name: StdString) -> Self { - match name.as_str() { - "__add" => MetaMethod::Add, - "__sub" => MetaMethod::Sub, - "__mul" => MetaMethod::Mul, - "__div" => MetaMethod::Div, - "__mod" => MetaMethod::Mod, - "__pow" => MetaMethod::Pow, - "__unm" => MetaMethod::Unm, - - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__idiv" => MetaMethod::IDiv, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__band" => MetaMethod::BAnd, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bor" => MetaMethod::BOr, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bxor" => MetaMethod::BXor, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bnot" => MetaMethod::BNot, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__shl" => MetaMethod::Shl, - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__shr" => MetaMethod::Shr, - - "__concat" => MetaMethod::Concat, - "__len" => MetaMethod::Len, - "__eq" => MetaMethod::Eq, - "__lt" => MetaMethod::Lt, - "__le" => MetaMethod::Le, - "__index" => MetaMethod::Index, - "__newindex" => MetaMethod::NewIndex, - "__call" => MetaMethod::Call, - "__tostring" => MetaMethod::ToString, - - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luajit52" - ))] - "__pairs" => MetaMethod::Pairs, - #[cfg(any(feature = "lua52", feature = "luajit52"))] - "__ipairs" => MetaMethod::IPairs, - #[cfg(feature = "luau")] - "__iter" => MetaMethod::Iter, - #[cfg(feature = "lua54")] - "__close" => MetaMethod::Close, - - _ => MetaMethod::Custom(name), + pub(crate) fn validate(name: &str) -> Result<&str> { + match name { + "__gc" => Err(Error::MetaMethodRestricted(name.to_string())), + "__metatable" => Err(Error::MetaMethodRestricted(name.to_string())), + _ if name.starts_with("__mlua") => Err(Error::MetaMethodRestricted(name.to_string())), + name => Ok(name), } } } -impl From<&str> for MetaMethod { - fn from(name: &str) -> Self { - MetaMethod::from(name.to_owned()) +impl AsRef for MetaMethod { + fn as_ref(&self) -> &str { + self.name() + } +} + +impl From for String { + #[inline] + fn from(method: MetaMethod) -> Self { + method.name().to_owned() } } /// Method registry for [`UserData`] implementors. -/// -/// [`UserData`]: crate::UserData -pub trait UserDataMethods<'lua, T: UserData> { +pub trait UserDataMethods { /// Add a regular method which accepts a `&T` as the first parameter. /// /// Regular methods are implemented by overriding the `__index` metamethod and returning the @@ -310,90 +305,138 @@ pub trait UserDataMethods<'lua, T: UserData> { /// /// If `add_meta_method` is used to set the `__index` metamethod, the `__index` metamethod will /// be used as a fall-back if no regular method is found. - fn add_method(&mut self, name: &S, method: M) + fn add_method(&mut self, name: impl Into, method: M) where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result; + M: Fn(&Lua, &T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; /// Add a regular method which accepts a `&mut T` as the first parameter. /// /// Refer to [`add_method`] for more information about the implementation. /// - /// [`add_method`]: #method.add_method - fn add_method_mut(&mut self, name: &S, method: M) + /// [`add_method`]: UserDataMethods::add_method + fn add_method_mut(&mut self, name: impl Into, method: M) + where + M: FnMut(&Lua, &mut T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; + + /// Add a method which accepts `T` as the first parameter. + /// + /// The userdata `T` will be moved out of the userdata container. This is useful for + /// methods that need to consume the userdata. + /// + /// The method can be called only once per userdata instance, subsequent calls will result in a + /// [`Error::UserDataDestructed`] error. + fn add_method_once(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(&Lua, T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let method_name = format!("{}.{name}", short_type_name::()); + self.add_function(name, move |lua, (ud, args): (AnyUserData, A)| { + let this = (ud.take()).map_err(|err| Error::bad_self_argument(&method_name, err))?; + method(lua, this, args) + }); + } + + /// Add an async method which accepts a `&T` as the first parameter and returns [`Future`]. + /// + /// Refer to [`add_method`] for more information about the implementation. + /// + /// [`add_method`]: UserDataMethods::add_method + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn add_async_method(&mut self, name: impl Into, method: M) where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result; + T: 'static, + M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; - /// Add an async method which accepts a `T` as the first parameter and returns Future. - /// The passed `T` is cloned from the original value. + /// Add an async method which accepts a `&mut T` as the first parameter and returns [`Future`]. /// /// Refer to [`add_method`] for more information about the implementation. /// - /// Requires `feature = "async"` + /// [`add_method`]: UserDataMethods::add_method + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn add_async_method_mut(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRefMut, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; + + /// Add an async method which accepts a `T` as the first parameter and returns [`Future`]. + /// + /// The userdata `T` will be moved out of the userdata container. This is useful for + /// methods that need to consume the userdata. /// - /// [`add_method`]: #method.add_method + /// The method can be called only once per userdata instance, subsequent calls will result in a + /// [`Error::UserDataDestructed`] error. #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_method(&mut self, name: &S, method: M) + fn add_async_method_once(&mut self, name: impl Into, method: M) where - T: Clone, - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>; - - /// Add a regular method as a function which accepts generic arguments, the first argument will - /// be a [`AnyUserData`] of type `T` if the method is called with Lua method syntax: - /// `my_userdata:my_method(arg1, arg2)`, or it is passed in as the first argument: - /// `my_userdata.my_method(my_userdata, arg1, arg2)`. - /// - /// Prefer to use [`add_method`] or [`add_method_mut`] as they are easier to use. - /// - /// [`AnyUserData`]: crate::AnyUserData - /// [`add_method`]: #method.add_method - /// [`add_method_mut`]: #method.add_method_mut - fn add_function(&mut self, name: &S, function: F) + T: 'static, + M: Fn(Lua, T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let method_name = format!("{}.{name}", short_type_name::()); + self.add_async_function(name, move |lua, (ud, args): (AnyUserData, A)| { + match (ud.take()).map_err(|err| Error::bad_self_argument(&method_name, err)) { + Ok(this) => either::Either::Left(method(lua, this, args)), + Err(err) => either::Either::Right(async move { Err(err) }), + } + }); + } + + /// Add a regular method as a function which accepts generic arguments. + /// + /// The first argument will be a [`AnyUserData`] of type `T` if the method is called with Lua + /// method syntax: `my_userdata:my_method(arg1, arg2)`, or it is passed in as the first + /// argument: `my_userdata.my_method(my_userdata, arg1, arg2)`. + fn add_function(&mut self, name: impl Into, function: F) where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result; + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; /// Add a regular method as a mutable function which accepts generic arguments. /// - /// This is a version of [`add_function`] that accepts a FnMut argument. + /// This is a version of [`add_function`] that accepts a `FnMut` argument. /// - /// [`add_function`]: #method.add_function - fn add_function_mut(&mut self, name: &S, function: F) + /// [`add_function`]: UserDataMethods::add_function + fn add_function_mut(&mut self, name: impl Into, function: F) where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result; + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; - /// Add a regular method as an async function which accepts generic arguments - /// and returns Future. + /// Add a regular method as an async function which accepts generic arguments and returns + /// [`Future`]. /// /// This is an async version of [`add_function`]. /// - /// Requires `feature = "async"` - /// - /// [`add_function`]: #method.add_function + /// [`add_function`]: UserDataMethods::add_function #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_function(&mut self, name: &S, function: F) + fn add_async_function(&mut self, name: impl Into, function: F) where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>; + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; /// Add a metamethod which accepts a `&T` as the first parameter. /// @@ -402,13 +445,12 @@ pub trait UserDataMethods<'lua, T: UserData> { /// This can cause an error with certain binary metamethods that can trigger if only the right /// side has a metatable. To prevent this, use [`add_meta_function`]. /// - /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_method(&mut self, meta: S, method: M) + /// [`add_meta_function`]: UserDataMethods::add_meta_function + fn add_meta_method(&mut self, name: impl Into, method: M) where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result; + M: Fn(&Lua, &T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; /// Add a metamethod as a function which accepts a `&mut T` as the first parameter. /// @@ -417,184 +459,173 @@ pub trait UserDataMethods<'lua, T: UserData> { /// This can cause an error with certain binary metamethods that can trigger if only the right /// side has a metatable. To prevent this, use [`add_meta_function`]. /// - /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_method_mut(&mut self, meta: S, method: M) + /// [`add_meta_function`]: UserDataMethods::add_meta_function + fn add_meta_method_mut(&mut self, name: impl Into, method: M) where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result; + M: FnMut(&Lua, &mut T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; - /// Add an async metamethod which accepts a `T` as the first parameter and returns Future. - /// The passed `T` is cloned from the original value. + /// Add an async metamethod which accepts a `&T` as the first parameter and returns [`Future`]. /// /// This is an async version of [`add_meta_method`]. /// - /// Requires `feature = "async"` + /// [`add_meta_method`]: UserDataMethods::add_meta_method + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + #[cfg_attr( + docsrs, + doc(cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))) + )] + fn add_async_meta_method(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; + + /// Add an async metamethod which accepts a `&mut T` as the first parameter and returns + /// [`Future`]. + /// + /// This is an async version of [`add_meta_method_mut`]. /// - /// [`add_meta_method`]: #method.add_meta_method + /// [`add_meta_method_mut`]: UserDataMethods::add_meta_method_mut #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_meta_method(&mut self, name: S, method: M) + fn add_async_meta_method_mut(&mut self, name: impl Into, method: M) where - T: Clone, - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>; + T: 'static, + M: Fn(Lua, UserDataRefMut, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; /// Add a metamethod which accepts generic arguments. /// /// Metamethods for binary operators can be triggered if either the left or right argument to /// the binary operator has a metatable, so the first argument here is not necessarily a /// userdata of type `T`. - fn add_meta_function(&mut self, meta: S, function: F) + fn add_meta_function(&mut self, name: impl Into, function: F) where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result; + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; /// Add a metamethod as a mutable function which accepts generic arguments. /// - /// This is a version of [`add_meta_function`] that accepts a FnMut argument. + /// This is a version of [`add_meta_function`] that accepts a `FnMut` argument. /// - /// [`add_meta_function`]: #method.add_meta_function - fn add_meta_function_mut(&mut self, meta: S, function: F) + /// [`add_meta_function`]: UserDataMethods::add_meta_function + fn add_meta_function_mut(&mut self, name: impl Into, function: F) where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result; + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti; - /// Add a metamethod which accepts generic arguments and returns Future. + /// Add a metamethod which accepts generic arguments and returns [`Future`]. /// /// This is an async version of [`add_meta_function`]. /// - /// Requires `feature = "async"` - /// - /// [`add_meta_function`]: #method.add_meta_function + /// [`add_meta_function`]: UserDataMethods::add_meta_function #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_meta_function(&mut self, name: S, function: F) + #[cfg_attr( + docsrs, + doc(cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))) + )] + fn add_async_meta_function(&mut self, name: impl Into, function: F) where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>; - - // - // Below are internal methods used in generated code - // - - #[doc(hidden)] - fn add_callback(&mut self, _name: Vec, _callback: Callback<'lua, 'static>) {} - - #[doc(hidden)] - #[cfg(feature = "async")] - fn add_async_callback(&mut self, _name: Vec, _callback: AsyncCallback<'lua, 'static>) {} - - #[doc(hidden)] - fn add_meta_callback(&mut self, _meta: MetaMethod, _callback: Callback<'lua, 'static>) {} - - #[doc(hidden)] - #[cfg(feature = "async")] - fn add_async_meta_callback( - &mut self, - _meta: MetaMethod, - _callback: AsyncCallback<'lua, 'static>, - ) { - } + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti; } /// Field registry for [`UserData`] implementors. -/// -/// [`UserData`]: crate::UserData -pub trait UserDataFields<'lua, T: UserData> { - /// Add a regular field getter as a method which accepts a `&T` as the parameter. +pub trait UserDataFields { + /// Add a static field to the [`UserData`]. /// - /// Regular field getters are implemented by overriding the `__index` metamethod and returning the + /// Static fields are implemented by updating the `__index` metamethod and returning the /// accessed field. This allows them to be used with the expected `userdata.field` syntax. /// + /// Static fields are usually shared between all instances of the [`UserData`] of the same type. + /// + /// If `add_meta_method` is used to set the `__index` metamethod, it will + /// be used as a fall-back if no regular field or method are found. + fn add_field(&mut self, name: impl Into, value: V) + where + V: IntoLua + 'static; + + /// Add a regular field getter as a method which accepts a `&T` as the parameter. + /// + /// Regular field getters are implemented by overriding the `__index` metamethod and returning + /// the accessed field. This allows them to be used with the expected `userdata.field` syntax. + /// /// If `add_meta_method` is used to set the `__index` metamethod, the `__index` metamethod will /// be used as a fall-back if no regular field or method are found. - fn add_field_method_get(&mut self, name: &S, method: M) + fn add_field_method_get(&mut self, name: impl Into, method: M) where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result; + M: Fn(&Lua, &T) -> Result + MaybeSend + 'static, + R: IntoLua; /// Add a regular field setter as a method which accepts a `&mut T` as the first parameter. /// - /// Regular field setters are implemented by overriding the `__newindex` metamethod and setting the - /// accessed field. This allows them to be used with the expected `userdata.field = value` syntax. + /// Regular field setters are implemented by overriding the `__newindex` metamethod and setting + /// the accessed field. This allows them to be used with the expected `userdata.field = value` + /// syntax. /// - /// If `add_meta_method` is used to set the `__newindex` metamethod, the `__newindex` metamethod will - /// be used as a fall-back if no regular field is found. - fn add_field_method_set(&mut self, name: &S, method: M) + /// If `add_meta_method` is used to set the `__newindex` metamethod, the `__newindex` metamethod + /// will be used as a fall-back if no regular field is found. + fn add_field_method_set(&mut self, name: impl Into, method: M) where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>; + M: FnMut(&Lua, &mut T, A) -> Result<()> + MaybeSend + 'static, + A: FromLua; /// Add a regular field getter as a function which accepts a generic [`AnyUserData`] of type `T` /// argument. - /// - /// Prefer to use [`add_field_method_get`] as it is easier to use. - /// - /// [`AnyUserData`]: crate::AnyUserData - /// [`add_field_method_get`]: #method.add_field_method_get - fn add_field_function_get(&mut self, name: &S, function: F) + fn add_field_function_get(&mut self, name: impl Into, function: F) where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result; + F: Fn(&Lua, AnyUserData) -> Result + MaybeSend + 'static, + R: IntoLua; /// Add a regular field setter as a function which accepts a generic [`AnyUserData`] of type `T` /// first argument. + fn add_field_function_set(&mut self, name: impl Into, function: F) + where + F: FnMut(&Lua, AnyUserData, A) -> Result<()> + MaybeSend + 'static, + A: FromLua; + + /// Add a metatable field. + /// + /// This will initialize the metatable field with `value` on [`UserData`] creation. /// - /// Prefer to use [`add_field_method_set`] as it is easier to use. + /// # Note /// - /// [`AnyUserData`]: crate::AnyUserData - /// [`add_field_method_set`]: #method.add_field_method_set - fn add_field_function_set(&mut self, name: &S, function: F) + /// `mlua` will trigger an error on an attempt to define a protected metamethod, + /// like `__gc` or `__metatable`. + fn add_meta_field(&mut self, name: impl Into, value: V) where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>; + V: IntoLua + 'static; - /// Add a metamethod value computed from `f`. + /// Add a metatable field computed from `f`. /// - /// This will initialize the metamethod value from `f` on `UserData` creation. + /// This will initialize the metatable field from `f` on [`UserData`] creation. /// /// # Note /// /// `mlua` will trigger an error on an attempt to define a protected metamethod, /// like `__gc` or `__metatable`. - fn add_meta_field_with(&mut self, meta: S, f: F) + fn add_meta_field_with(&mut self, name: impl Into, f: F) where - S: Into, - F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, - R: ToLua<'lua>; - - // - // Below are internal methods used in generated code - // - - #[doc(hidden)] - fn add_field_getter(&mut self, _name: Vec, _callback: Callback<'lua, 'static>) {} - - #[doc(hidden)] - fn add_field_setter(&mut self, _name: Vec, _callback: Callback<'lua, 'static>) {} + F: FnOnce(&Lua) -> Result + 'static, + R: IntoLua; } /// Trait for custom userdata types. /// /// By implementing this trait, a struct becomes eligible for use inside Lua code. -/// Implementation of [`ToLua`] is automatically provided, [`FromLua`] is implemented -/// only for `T: UserData + Clone`. +/// +/// Implementation of [`IntoLua`] is automatically provided, [`FromLua`] needs to be implemented +/// manually. /// /// /// # Examples @@ -603,20 +634,20 @@ pub trait UserDataFields<'lua, T: UserData> { /// # use mlua::{Lua, Result, UserData}; /// # fn main() -> Result<()> { /// # let lua = Lua::new(); -/// struct MyUserData(i32); +/// struct MyUserData; /// /// impl UserData for MyUserData {} /// -/// // `MyUserData` now implements `ToLua`: -/// lua.globals().set("myobject", MyUserData(123))?; +/// // `MyUserData` now implements `IntoLua`: +/// lua.globals().set("myobject", MyUserData)?; /// /// lua.load("assert(type(myobject) == 'userdata')").exec()?; /// # Ok(()) /// # } /// ``` /// -/// Custom fields, methods and operators can be provided by implementing `add_fields` or `add_methods` -/// (refer to [`UserDataFields`] and [`UserDataMethods`] for more information): +/// Custom fields, methods and operators can be provided by implementing `add_fields` or +/// `add_methods` (refer to [`UserDataFields`] and [`UserDataMethods`] for more information): /// /// ``` /// # use mlua::{Lua, MetaMethod, Result, UserData, UserDataFields, UserDataMethods}; @@ -625,12 +656,12 @@ pub trait UserDataFields<'lua, T: UserData> { /// struct MyUserData(i32); /// /// impl UserData for MyUserData { -/// fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { +/// fn add_fields>(fields: &mut F) { /// fields.add_field_method_get("val", |_, this| Ok(this.0)); /// } /// -/// fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { -/// methods.add_method_mut("add", |_, this, value: i32| { +/// fn add_methods>(methods: &mut M) { +/// methods.add_method_mut("add", |_, mut this, value: i32| { /// this.0 += value; /// Ok(()) /// }); @@ -652,297 +683,213 @@ pub trait UserDataFields<'lua, T: UserData> { /// # Ok(()) /// # } /// ``` -/// -/// [`ToLua`]: crate::ToLua -/// [`FromLua`]: crate::FromLua -/// [`UserDataFields`]: crate::UserDataFields -/// [`UserDataMethods`]: crate::UserDataMethods pub trait UserData: Sized { /// Adds custom fields specific to this userdata. - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(_fields: &mut F) {} + #[allow(unused_variables)] + fn add_fields>(fields: &mut F) {} /// Adds custom methods and operators specific to this userdata. - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(_methods: &mut M) {} -} - -// Wraps UserData in a way to always implement `serde::Serialize` trait. -pub(crate) struct UserDataCell(RefCell>); - -impl UserDataCell { - #[inline] - pub(crate) fn new(data: T) -> Self { - UserDataCell(RefCell::new(UserDataWrapped::new(data))) - } - - #[cfg(feature = "serialize")] - #[inline] - pub(crate) fn new_ser(data: T) -> Self - where - T: 'static + Serialize, - { - UserDataCell(RefCell::new(UserDataWrapped::new_ser(data))) - } - - // Immutably borrows the wrapped value. - #[inline] - pub(crate) fn try_borrow(&self) -> Result> { - self.0 - .try_borrow() - .map(|r| Ref::map(r, |r| r.deref())) - .map_err(|_| Error::UserDataBorrowError) - } - - // Mutably borrows the wrapped value. - #[inline] - pub(crate) fn try_borrow_mut(&self) -> Result> { - self.0 - .try_borrow_mut() - .map(|r| RefMut::map(r, |r| r.deref_mut())) - .map_err(|_| Error::UserDataBorrowMutError) - } - - // Consumes this `UserDataCell`, returning the wrapped value. - #[inline] - fn into_inner(self) -> T { - self.0.into_inner().into_inner() - } -} - -pub(crate) enum UserDataWrapped { - Default(Box), - #[cfg(feature = "serialize")] - Serializable(Box), -} + #[allow(unused_variables)] + fn add_methods>(methods: &mut M) {} -impl UserDataWrapped { - #[inline] - fn new(data: T) -> Self { - UserDataWrapped::Default(Box::new(data)) - } - - #[cfg(feature = "serialize")] - #[inline] - fn new_ser(data: T) -> Self - where - T: 'static + Serialize, - { - UserDataWrapped::Serializable(Box::new(data)) - } - - #[inline] - fn into_inner(self) -> T { - match self { - Self::Default(data) => *data, - #[cfg(feature = "serialize")] - Self::Serializable(data) => unsafe { *Box::from_raw(Box::into_raw(data) as *mut T) }, - } - } -} - -impl Deref for UserDataWrapped { - type Target = T; - - #[inline] - fn deref(&self) -> &Self::Target { - match self { - Self::Default(data) => data, - #[cfg(feature = "serialize")] - Self::Serializable(data) => unsafe { - &*(data.as_ref() as *const _ as *const Self::Target) - }, - } - } -} - -impl DerefMut for UserDataWrapped { - #[inline] - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - Self::Default(data) => data, - #[cfg(feature = "serialize")] - Self::Serializable(data) => unsafe { - &mut *(data.as_mut() as *mut _ as *mut Self::Target) - }, - } - } -} - -#[cfg(feature = "serialize")] -struct UserDataSerializeError; - -#[cfg(feature = "serialize")] -impl Serialize for UserDataSerializeError { - fn serialize(&self, _serializer: S) -> StdResult - where - S: Serializer, - { - Err(ser::Error::custom("cannot serialize ")) + /// Registers this type for use in Lua. + /// + /// This method is responsible for calling `add_fields` and `add_methods` on the provided + /// [`UserDataRegistry`]. + fn register(registry: &mut UserDataRegistry) { + Self::add_fields(registry); + Self::add_methods(registry); } } /// Handle to an internal Lua userdata for any type that implements [`UserData`]. /// -/// Similar to `std::any::Any`, this provides an interface for dynamic type checking via the [`is`] -/// and [`borrow`] methods. -/// -/// Internally, instances are stored in a `RefCell`, to best match the mutable semantics of the Lua -/// language. +/// Similar to [`std::any::Any`], this provides an interface for dynamic type checking via the +/// [`is`] and [`borrow`] methods. /// /// # Note /// /// This API should only be used when necessary. Implementing [`UserData`] already allows defining /// methods which check the type and acquire a borrow behind the scenes. /// -/// [`UserData`]: crate::UserData /// [`is`]: crate::AnyUserData::is /// [`borrow`]: crate::AnyUserData::borrow -#[derive(Clone, Debug)] -pub struct AnyUserData<'lua>(pub(crate) LuaRef<'lua>); +#[derive(Clone, PartialEq)] +pub struct AnyUserData(pub(crate) ValueRef); -impl<'lua> AnyUserData<'lua> { +impl AnyUserData { /// Checks whether the type of this userdata is `T`. - pub fn is(&self) -> bool { - match self.inspect(|_: &UserDataCell| Ok(())) { - Ok(()) => true, - Err(Error::UserDataTypeMismatch) => false, - Err(_) => unreachable!(), - } + #[inline] + pub fn is(&self) -> bool { + let type_id = self.type_id(); + // We do not use wrapped types here, rather prefer to check the "real" type of the userdata + matches!(type_id, Some(type_id) if type_id == TypeId::of::()) + } + + /// Checks whether the type of this userdata is a [proxy object] for `T`. + /// + /// [proxy object]: crate::Lua::create_proxy + #[inline] + pub fn is_proxy(&self) -> bool { + self.is::>() } /// Borrow this userdata immutably if it is of type `T`. /// /// # Errors /// - /// Returns a `UserDataBorrowError` if the userdata is already mutably borrowed. Returns a - /// `UserDataTypeMismatch` if the userdata is not of type `T`. + /// Returns a [`UserDataBorrowError`] if the userdata is already mutably borrowed. + /// Returns a [`DataTypeMismatch`] if the userdata is not of type `T` or if it's + /// scoped. + /// + /// [`UserDataBorrowError`]: crate::Error::UserDataBorrowError + /// [`DataTypeMismatch`]: crate::Error::UserDataTypeMismatch #[inline] - pub fn borrow(&self) -> Result> { - self.inspect(|cell| cell.try_borrow()) + pub fn borrow(&self) -> Result> { + let lua = self.0.lua.lock(); + unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + } + + /// Borrow this userdata immutably if it is of type `T`, passing the borrowed value + /// to the closure. + /// + /// This method is the only way to borrow scoped userdata (created inside [`Lua::scope`]). + pub fn borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { + let lua = self.0.lua.lock(); + let type_id = lua.get_userdata_ref_type_id(&self.0)?; + let type_hints = TypeIdHints::new::(); + unsafe { borrow_userdata_scoped(lua.ref_thread(), self.0.index, type_id, type_hints, f) } } /// Borrow this userdata mutably if it is of type `T`. /// /// # Errors /// - /// Returns a `UserDataBorrowMutError` if the userdata cannot be mutably borrowed. - /// Returns a `UserDataTypeMismatch` if the userdata is not of type `T`. + /// Returns a [`UserDataBorrowMutError`] if the userdata cannot be mutably borrowed. + /// Returns a [`UserDataTypeMismatch`] if the userdata is not of type `T` or if it's + /// scoped. + /// + /// [`UserDataBorrowMutError`]: crate::Error::UserDataBorrowMutError + /// [`UserDataTypeMismatch`]: crate::Error::UserDataTypeMismatch #[inline] - pub fn borrow_mut(&self) -> Result> { - self.inspect(|cell| cell.try_borrow_mut()) + pub fn borrow_mut(&self) -> Result> { + let lua = self.0.lua.lock(); + unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } + } + + /// Borrow this userdata mutably if it is of type `T`, passing the borrowed value + /// to the closure. + /// + /// This method is the only way to borrow scoped userdata (created inside [`Lua::scope`]). + pub fn borrow_mut_scoped(&self, f: impl FnOnce(&mut T) -> R) -> Result { + let lua = self.0.lua.lock(); + let type_id = lua.get_userdata_ref_type_id(&self.0)?; + let type_hints = TypeIdHints::new::(); + unsafe { borrow_userdata_scoped_mut(lua.ref_thread(), self.0.index, type_id, type_hints, f) } + } + + /// Takes the value out of this userdata. + /// + /// Sets the special "destructed" metatable that prevents any further operations with this + /// userdata. + /// + /// Keeps associated user values unchanged (they will be collected by Lua's GC). + pub fn take(&self) -> Result { + let lua = self.0.lua.lock(); + match lua.get_userdata_ref_type_id(&self.0)? { + Some(type_id) if type_id == TypeId::of::() => unsafe { + let ref_thread = lua.ref_thread(); + if (*get_userdata::>(ref_thread, self.0.index)).has_exclusive_access() { + take_userdata::>(ref_thread, self.0.index).into_inner() + } else { + Err(Error::UserDataBorrowMutError) + } + }, + _ => Err(Error::UserDataTypeMismatch), + } } - /// Takes out the value of `UserData` and sets the special "destructed" metatable that prevents - /// any further operations with this userdata. + /// Destroys this userdata. /// - /// All associated user values will be also cleared. - pub fn take(&self) -> Result { - let lua = self.0.lua; + /// This is similar to [`AnyUserData::take`], but it doesn't require a type. + /// + /// This method works for non-scoped userdata only. + pub fn destroy(&self) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 3)?; - - let type_id = lua.push_userdata_ref(&self.0)?; - match type_id { - Some(type_id) if type_id == TypeId::of::() => { - // Try to borrow userdata exclusively - let _ = (*get_userdata::>(lua.state, -1)).try_borrow_mut()?; - - // Clear associated user values - #[cfg(feature = "lua54")] - for i in 1..=USER_VALUE_MAXSLOT { - ffi::lua_pushnil(lua.state); - ffi::lua_setiuservalue(lua.state, -2, i as c_int); - } - #[cfg(any(feature = "lua53", feature = "lua52", feature = "luau"))] - { - ffi::lua_pushnil(lua.state); - ffi::lua_setuservalue(lua.state, -2); - } - #[cfg(any(feature = "lua51", feature = "luajit"))] - protect_lua!(lua.state, 1, 1, fn(state) { - ffi::lua_newtable(state); - ffi::lua_setuservalue(state, -2); - })?; + let _sg = StackGuard::new(state); + check_stack(state, 3)?; - Ok(take_userdata::>(lua.state).into_inner()) + lua.push_userdata_ref(&self.0)?; + protect_lua!(state, 1, 1, fn(state) { + if ffi::luaL_callmeta(state, -1, cstr!("__gc")) == 0 { + ffi::lua_pushboolean(state, 0); } - _ => Err(Error::UserDataTypeMismatch), + })?; + if ffi::lua_isboolean(state, -1) != 0 && ffi::lua_toboolean(state, -1) != 0 { + return Ok(()); } + Err(Error::UserDataBorrowMutError) } } - /// Sets an associated value to this `AnyUserData`. + /// Sets an associated value to this [`AnyUserData`]. /// - /// The value may be any Lua value whatsoever, and can be retrieved with [`get_user_value`]. + /// The value may be any Lua value whatsoever, and can be retrieved with [`user_value`]. /// /// This is the same as calling [`set_nth_user_value`] with `n` set to 1. /// - /// [`get_user_value`]: #method.get_user_value - /// [`set_nth_user_value`]: #method.set_nth_user_value + /// [`user_value`]: AnyUserData::user_value + /// [`set_nth_user_value`]: AnyUserData::set_nth_user_value #[inline] - pub fn set_user_value>(&self, v: V) -> Result<()> { + pub fn set_user_value(&self, v: impl IntoLua) -> Result<()> { self.set_nth_user_value(1, v) } /// Returns an associated value set by [`set_user_value`]. /// - /// This is the same as calling [`get_nth_user_value`] with `n` set to 1. + /// This is the same as calling [`nth_user_value`] with `n` set to 1. /// - /// [`set_user_value`]: #method.set_user_value - /// [`get_nth_user_value`]: #method.get_nth_user_value + /// [`set_user_value`]: AnyUserData::set_user_value + /// [`nth_user_value`]: AnyUserData::nth_user_value #[inline] - pub fn get_user_value>(&self) -> Result { - self.get_nth_user_value(1) + pub fn user_value(&self) -> Result { + self.nth_user_value(1) } - /// Sets an associated `n`th value to this `AnyUserData`. + /// Sets an associated `n`th value to this [`AnyUserData`]. /// - /// The value may be any Lua value whatsoever, and can be retrieved with [`get_nth_user_value`]. + /// The value may be any Lua value whatsoever, and can be retrieved with [`nth_user_value`]. /// `n` starts from 1 and can be up to 65535. /// - /// This is supported for all Lua versions. - /// In Lua 5.4 first 7 elements are stored in a most efficient way. - /// For other Lua versions this functionality is provided using a wrapping table. + /// This is supported for all Lua versions using a wrapping table. /// - /// [`get_nth_user_value`]: #method.get_nth_user_value - pub fn set_nth_user_value>(&self, n: usize, v: V) -> Result<()> { + /// [`nth_user_value`]: AnyUserData::nth_user_value + pub fn set_nth_user_value(&self, n: usize, v: impl IntoLua) -> Result<()> { if n < 1 || n > u16::MAX as usize { - return Err(Error::RuntimeError( - "user value index out of bounds".to_string(), - )); + return Err(Error::runtime("user value index out of bounds")); } - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_userdata_ref(&self.0)?; - lua.push_value(v.to_lua(lua)?)?; - - #[cfg(feature = "lua54")] - if n < USER_VALUE_MAXSLOT { - ffi::lua_setiuservalue(lua.state, -2, n as c_int); - return Ok(()); - } + lua.push(v)?; // Multiple (extra) user values are emulated by storing them in a table - protect_lua!(lua.state, 2, 0, |state| { - if getuservalue_table(state, -2) != ffi::LUA_TTABLE { + protect_lua!(state, 2, 0, |state| { + if ffi::lua_getuservalue(state, -2) != ffi::LUA_TTABLE { // Create a new table to use as uservalue ffi::lua_pop(state, 1); ffi::lua_newtable(state); ffi::lua_pushvalue(state, -1); - - #[cfg(feature = "lua54")] - ffi::lua_setiuservalue(state, -4, USER_VALUE_MAXSLOT as c_int); - #[cfg(not(feature = "lua54"))] ffi::lua_setuservalue(state, -4); } ffi::lua_pushvalue(state, -2); - #[cfg(feature = "lua54")] - ffi::lua_rawseti(state, -2, (n - USER_VALUE_MAXSLOT + 1) as ffi::lua_Integer); - #[cfg(not(feature = "lua54"))] ffi::lua_rawseti(state, -2, n as ffi::lua_Integer); })?; @@ -954,77 +901,54 @@ impl<'lua> AnyUserData<'lua> { /// /// `n` starts from 1 and can be up to 65535. /// - /// This is supported for all Lua versions. - /// In Lua 5.4 first 7 elements are stored in a most efficient way. - /// For other Lua versions this functionality is provided using a wrapping table. + /// This is supported for all Lua versions using a wrapping table. /// - /// [`set_nth_user_value`]: #method.set_nth_user_value - pub fn get_nth_user_value>(&self, n: usize) -> Result { + /// [`set_nth_user_value`]: AnyUserData::set_nth_user_value + pub fn nth_user_value(&self, n: usize) -> Result { if n < 1 || n > u16::MAX as usize { - return Err(Error::RuntimeError( - "user value index out of bounds".to_string(), - )); + return Err(Error::runtime("user value index out of bounds")); } - let lua = self.0.lua; + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_userdata_ref(&self.0)?; - #[cfg(feature = "lua54")] - if n < USER_VALUE_MAXSLOT { - ffi::lua_getiuservalue(lua.state, -1, n as c_int); - return V::from_lua(lua.pop_value(), lua); - } - // Multiple (extra) user values are emulated by storing them in a table - protect_lua!(lua.state, 1, 1, |state| { - if getuservalue_table(state, -1) != ffi::LUA_TTABLE { - ffi::lua_pushnil(state); - return; - } - #[cfg(feature = "lua54")] - ffi::lua_rawgeti(state, -1, (n - USER_VALUE_MAXSLOT + 1) as ffi::lua_Integer); - #[cfg(not(feature = "lua54"))] - ffi::lua_rawgeti(state, -1, n as ffi::lua_Integer); - })?; + if ffi::lua_getuservalue(state, -1) != ffi::LUA_TTABLE { + return V::from_lua(Value::Nil, lua.lua()); + } + ffi::lua_rawgeti(state, -1, n as ffi::lua_Integer); - V::from_lua(lua.pop_value(), lua) + V::from_lua(lua.pop_value(), lua.lua()) } } - /// Sets an associated value to this `AnyUserData` by name. + /// Sets an associated value to this [`AnyUserData`] by name. /// - /// The value can be retrieved with [`get_named_user_value`]. + /// The value can be retrieved with [`named_user_value`]. /// - /// [`get_named_user_value`]: #method.get_named_user_value - pub fn set_named_user_value(&self, name: &S, v: V) -> Result<()> - where - S: AsRef<[u8]> + ?Sized, - V: ToLua<'lua>, - { - let lua = self.0.lua; + /// [`named_user_value`]: AnyUserData::named_user_value + pub fn set_named_user_value(&self, name: &str, v: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 5)?; + let _sg = StackGuard::new(state); + check_stack(state, 5)?; lua.push_userdata_ref(&self.0)?; - lua.push_value(v.to_lua(lua)?)?; + lua.push(v)?; // Multiple (extra) user values are emulated by storing them in a table - let name = name.as_ref(); - protect_lua!(lua.state, 2, 0, |state| { - if getuservalue_table(state, -2) != ffi::LUA_TTABLE { + protect_lua!(state, 2, 0, |state| { + if ffi::lua_getuservalue(state, -2) != ffi::LUA_TTABLE { // Create a new table to use as uservalue ffi::lua_pop(state, 1); ffi::lua_newtable(state); ffi::lua_pushvalue(state, -1); - - #[cfg(feature = "lua54")] - ffi::lua_setiuservalue(state, -4, USER_VALUE_MAXSLOT as c_int); - #[cfg(not(feature = "lua54"))] ffi::lua_setuservalue(state, -4); } ffi::lua_pushlstring(state, name.as_ptr() as *const c_char, name.len()); @@ -1038,186 +962,235 @@ impl<'lua> AnyUserData<'lua> { /// Returns an associated value by name set by [`set_named_user_value`]. /// - /// [`set_named_user_value`]: #method.set_named_user_value - pub fn get_named_user_value(&self, name: &S) -> Result - where - S: AsRef<[u8]> + ?Sized, - V: FromLua<'lua>, - { - let lua = self.0.lua; + /// [`set_named_user_value`]: AnyUserData::set_named_user_value + pub fn named_user_value(&self, name: &str) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 4)?; + let _sg = StackGuard::new(state); + check_stack(state, 4)?; lua.push_userdata_ref(&self.0)?; // Multiple (extra) user values are emulated by storing them in a table - let name = name.as_ref(); - protect_lua!(lua.state, 1, 1, |state| { - if getuservalue_table(state, -1) != ffi::LUA_TTABLE { - ffi::lua_pushnil(state); - return; - } - ffi::lua_pushlstring(state, name.as_ptr() as *const c_char, name.len()); - ffi::lua_rawget(state, -2); - })?; + if ffi::lua_getuservalue(state, -1) != ffi::LUA_TTABLE { + return V::from_lua(Value::Nil, lua.lua()); + } + push_string(state, name.as_bytes(), !lua.unlikely_memory_error())?; + ffi::lua_rawget(state, -2); - V::from_lua(lua.pop_value(), lua) + V::from_stack(-1, &lua) } } - /// Returns a metatable of this `UserData`. + /// Returns a metatable of this [`AnyUserData`]. /// /// Returned [`UserDataMetatable`] object wraps the original metatable and /// provides safe access to its methods. /// - /// For `T: UserData + 'static` returned metatable is shared among all instances of type `T`. + /// For `T: 'static` returned metatable is shared among all instances of type `T`. + #[inline] + pub fn metatable(&self) -> Result { + self.raw_metatable().map(UserDataMetatable) + } + + /// Returns a raw metatable of this [`AnyUserData`]. + fn raw_metatable(&self) -> Result
{ + let lua = self.0.lua.lock(); + let ref_thread = lua.ref_thread(); + unsafe { + // Check that userdata is registered and not destructed + // All registered userdata types have a non-empty metatable + let _type_id = lua.get_userdata_ref_type_id(&self.0)?; + + ffi::lua_getmetatable(ref_thread, self.0.index); + Ok(Table(lua.pop_ref_thread())) + } + } + + /// Converts this userdata to a generic C pointer. /// - /// [`UserDataMetatable`]: crate::UserDataMetatable - pub fn get_metatable(&self) -> Result> { - self.get_raw_metatable().map(UserDataMetatable) + /// There is no way to convert the pointer back to its original value. + /// + /// Typically this function is used only for hashing and debug information. + #[inline] + pub fn to_pointer(&self) -> *const c_void { + self.0.to_pointer() + } + + /// Returns [`TypeId`] of this userdata if it is registered and `'static`. + /// + /// This method is not available for scoped userdata. + #[inline] + pub fn type_id(&self) -> Option { + let lua = self.0.lua.lock(); + lua.get_userdata_ref_type_id(&self.0).ok().flatten() } - fn get_raw_metatable(&self) -> Result> { + /// Returns a type name of this userdata (from a metatable field). + /// + /// If no type name is set, returns `userdata`. + pub fn type_name(&self) -> Result { + let lua = self.0.lua.lock(); + let state = lua.state(); unsafe { - let lua = self.0.lua; - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 3)?; + let _sg = StackGuard::new(state); + check_stack(state, 3)?; lua.push_userdata_ref(&self.0)?; - ffi::lua_getmetatable(lua.state, -1); // Checked that non-empty on the previous call - Ok(Table(lua.pop_ref())) + let protect = !lua.unlikely_memory_error(); + let name_type = if protect { + protect_lua!(state, 1, 1, |state| { + ffi::luaL_getmetafield(state, -1, MetaMethod::Type.as_cstr().as_ptr()) + })? + } else { + ffi::luaL_getmetafield(state, -1, MetaMethod::Type.as_cstr().as_ptr()) + }; + match name_type { + ffi::LUA_TSTRING => Ok(LuaString(lua.pop_ref())), + _ => lua.create_string(b"userdata"), + } } } - pub(crate) fn equals>(&self, other: T) -> Result { - let other = other.as_ref(); + pub(crate) fn equals(&self, other: &Self) -> Result { // Uses lua_rawequal() under the hood if self == other { return Ok(true); } - let mt = self.get_raw_metatable()?; - if mt != other.get_raw_metatable()? { + let mt = self.raw_metatable()?; + if mt != other.raw_metatable()? { return Ok(false); } if mt.contains_key("__eq")? { - return mt - .get::<_, Function>("__eq")? - .call((self.clone(), other.clone())); + return mt.get::("__eq")?.call((self, other)); } Ok(false) } - fn inspect<'a, T, R, F>(&'a self, func: F) -> Result - where - T: 'static + UserData, - F: FnOnce(&'a UserDataCell) -> Result, - { - let lua = self.0.lua; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; + /// Returns `true` if this [`AnyUserData`] is serializable (e.g. was created using + /// [`Lua::create_ser_userdata`]). + #[cfg(feature = "serde")] + pub(crate) fn is_serializable(&self) -> bool { + let lua = self.0.lua.lock(); + let is_serializable = || unsafe { + // Userdata must be registered and not destructed + let _ = lua.get_userdata_ref_type_id(&self.0)?; + let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + Ok::<_, Error>((*ud).is_serializable()) + }; + is_serializable().unwrap_or(false) + } - let type_id = lua.push_userdata_ref(&self.0)?; - match type_id { - Some(type_id) if type_id == TypeId::of::() => { - func(&*get_userdata::>(lua.state, -1)) + unsafe fn invoke_tostring_dbg(&self) -> Result> { + let lua = self.0.lua.lock(); + let state = lua.state(); + let _guard = StackGuard::new(state); + check_stack(state, 3)?; + + lua.push_ref(&self.0); + protect_lua!(state, 1, 1, fn(state) { + // Try `__todebugstring` metamethod first, then `__tostring` + #[allow(clippy::collapsible_if)] + if ffi::luaL_callmeta(state, -1, cstr!("__todebugstring")) == 0 { + if ffi::luaL_callmeta(state, -1, cstr!("__tostring")) == 0 { + ffi::lua_pushnil(state); } - _ => Err(Error::UserDataTypeMismatch), } - } + })?; + Ok(lua.pop_value().as_string().map(|s| s.to_string_lossy())) } -} -impl<'lua> PartialEq for AnyUserData<'lua> { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + pub(crate) fn fmt_pretty(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + // Try converting to a (debug) string first, with fallback to `__name/__type` + match unsafe { self.invoke_tostring_dbg() } { + Ok(Some(s)) => write!(fmt, "{s}"), + _ => { + let name = self.type_name().ok(); + let name = (name.as_ref()) + .map(|s| Either::Left(s.display())) + .unwrap_or(Either::Right("userdata")); + write!(fmt, "{name}: {:?}", self.to_pointer()) + } + } } } -impl<'lua> AsRef> for AnyUserData<'lua> { - #[inline] - fn as_ref(&self) -> &Self { - self +impl fmt::Debug for AnyUserData { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + if fmt.alternate() { + return self.fmt_pretty(fmt); + } + fmt.debug_tuple("AnyUserData").field(&self.0).finish() } } -unsafe fn getuservalue_table(state: *mut ffi::lua_State, idx: c_int) -> c_int { - #[cfg(feature = "lua54")] - return ffi::lua_getiuservalue(state, idx, USER_VALUE_MAXSLOT as c_int); - #[cfg(not(feature = "lua54"))] - return ffi::lua_getuservalue(state, idx); -} - -/// Handle to a `UserData` metatable. +/// Handle to a [`AnyUserData`] metatable. #[derive(Clone, Debug)] -pub struct UserDataMetatable<'lua>(pub(crate) Table<'lua>); +pub struct UserDataMetatable(pub(crate) Table); -impl<'lua> UserDataMetatable<'lua> { +impl UserDataMetatable { /// Gets the value associated to `key` from the metatable. /// /// If no value is associated to `key`, returns the `Nil` value. /// Access to restricted metamethods such as `__gc` or `__metatable` will cause an error. - pub fn get, V: FromLua<'lua>>(&self, key: K) -> Result { - self.0.raw_get(key.into().validate()?.name()) + pub fn get(&self, key: impl AsRef) -> Result { + self.0.raw_get(MetaMethod::validate(key.as_ref())?) } /// Sets a key-value pair in the metatable. /// /// If the value is `Nil`, this will effectively remove the `key`. /// Access to restricted metamethods such as `__gc` or `__metatable` will cause an error. - /// Setting `__index` or `__newindex` metamethods is also restricted because their values are cached - /// for `mlua` internal usage. - pub fn set, V: ToLua<'lua>>(&self, key: K, value: V) -> Result<()> { - let key = key.into().validate()?; + /// Setting `__index` or `__newindex` metamethods is also restricted because their values are + /// cached for `mlua` internal usage. + pub fn set(&self, key: impl AsRef, value: impl IntoLua) -> Result<()> { + let key = MetaMethod::validate(key.as_ref())?; // `__index` and `__newindex` cannot be changed in runtime, because values are cached if key == MetaMethod::Index || key == MetaMethod::NewIndex { return Err(Error::MetaMethodRestricted(key.to_string())); } - self.0.raw_set(key.name(), value) + self.0.raw_set(key, value) } /// Checks whether the metatable contains a non-nil value for `key`. - pub fn contains>(&self, key: K) -> Result { - self.0.contains_key(key.into().validate()?.name()) + pub fn contains(&self, key: impl AsRef) -> Result { + self.0.contains_key(MetaMethod::validate(key.as_ref())?) } - /// Consumes this metatable and returns an iterator over the pairs of the metatable. + /// Returns an iterator over the pairs of the metatable. /// /// The pairs are wrapped in a [`Result`], since they are lazily converted to `V` type. /// /// [`Result`]: crate::Result - pub fn pairs>(self) -> UserDataMetatablePairs<'lua, V> { + pub fn pairs(&self) -> UserDataMetatablePairs<'_, V> { UserDataMetatablePairs(self.0.pairs()) } } -/// An iterator over the pairs of a [`UserData`] metatable. +/// An iterator over the pairs of a [`AnyUserData`] metatable. /// /// It skips restricted metamethods, such as `__gc` or `__metatable`. /// /// This struct is created by the [`UserDataMetatable::pairs`] method. -/// -/// [`UserData`]: crate::UserData -/// [`UserDataMetatable::pairs`]: crate::UserDataMetatable::method.pairs -pub struct UserDataMetatablePairs<'lua, V>(TablePairs<'lua, StdString, V>); +pub struct UserDataMetatablePairs<'a, V>(TablePairs<'a, String, V>); -impl<'lua, V> Iterator for UserDataMetatablePairs<'lua, V> +impl Iterator for UserDataMetatablePairs<'_, V> where - V: FromLua<'lua>, + V: FromLua, { - type Item = Result<(MetaMethod, V)>; + type Item = Result<(String, V)>; fn next(&mut self) -> Option { loop { match self.0.next()? { Ok((key, value)) => { // Skip restricted metamethods - if let Ok(metamethod) = MetaMethod::from(key).validate() { - break Some(Ok((metamethod, value))); + if MetaMethod::validate(&key).is_ok() { + break Some(Ok((key, value))); } } Err(e) => break Some(Err(e)), @@ -1226,25 +1199,66 @@ where } } -#[cfg(feature = "serialize")] -impl<'lua> Serialize for AnyUserData<'lua> { +#[cfg(feature = "serde")] +impl Serialize for AnyUserData { fn serialize(&self, serializer: S) -> StdResult where S: Serializer, { - let lua = self.0.lua; - let data = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 3).map_err(ser::Error::custom)?; - - lua.push_userdata_ref(&self.0).map_err(ser::Error::custom)?; - let ud = &*get_userdata::>(lua.state, -1); - ud.0.try_borrow() - .map_err(|_| ser::Error::custom(Error::UserDataBorrowError))? - }; - match &*data { - UserDataWrapped::Default(_) => UserDataSerializeError.serialize(serializer), - UserDataWrapped::Serializable(ser) => ser.serialize(serializer), + let lua = self.0.lua.lock(); + unsafe { + let _ = lua + .get_userdata_ref_type_id(&self.0) + .map_err(ser::Error::custom)?; + let ud = &*get_userdata::>(lua.ref_thread(), self.0.index); + ud.serialize(serializer) } } } + +struct WrappedUserdata Result>(F); + +impl AnyUserData { + /// Wraps any Rust type, returning an opaque type that implements [`IntoLua`] trait. + /// + /// This function uses [`Lua::create_any_userdata`] under the hood. + pub fn wrap(data: T) -> impl IntoLua { + WrappedUserdata(move |lua| lua.create_any_userdata(data)) + } + + /// Wraps any Rust type that implements [`Serialize`], returning an opaque type that implements + /// [`IntoLua`] trait. + /// + /// This function uses [`Lua::create_ser_any_userdata`] under the hood. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + pub fn wrap_ser(data: T) -> impl IntoLua { + WrappedUserdata(move |lua| lua.create_ser_any_userdata(data)) + } +} + +impl IntoLua for WrappedUserdata +where + F: for<'l> FnOnce(&'l Lua) -> Result, +{ + fn into_lua(self, lua: &Lua) -> Result { + (self.0)(lua).map(Value::UserData) + } +} + +mod cell; +mod lock; +mod object; +mod r#ref; +mod registry; +mod util; + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(AnyUserData: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(AnyUserData: Send, Sync); +} diff --git a/src/userdata/cell.rs b/src/userdata/cell.rs new file mode 100644 index 00000000..f0058fd8 --- /dev/null +++ b/src/userdata/cell.rs @@ -0,0 +1,266 @@ +use std::cell::RefCell; + +#[cfg(feature = "serde")] +use serde::ser::{Serialize, Serializer}; + +use crate::error::{Error, Result}; +use crate::types::XRc; + +use super::lock::{RawLock, RwLock, UserDataLock}; +use super::r#ref::{UserDataRef, UserDataRefMut}; + +#[cfg(all(feature = "serde", not(feature = "send")))] +type DynSerialize = dyn erased_serde::Serialize; + +#[cfg(all(feature = "serde", feature = "send"))] +type DynSerialize = dyn erased_serde::Serialize + Send + Sync; + +pub(crate) enum UserDataStorage { + Owned(UserDataVariant), + Scoped(ScopedUserDataVariant), +} + +// A enum for storing userdata values. +// It's stored inside a Lua VM and protected by the outer `ReentrantMutex`. +pub(crate) enum UserDataVariant { + Default(XRc>), + #[cfg(feature = "serde")] + Serializable(XRc>>), +} + +impl Clone for UserDataVariant { + #[inline] + fn clone(&self) -> Self { + match self { + Self::Default(inner) => Self::Default(XRc::clone(inner)), + #[cfg(feature = "serde")] + Self::Serializable(inner) => Self::Serializable(XRc::clone(inner)), + } + } +} + +impl UserDataVariant { + #[inline(always)] + pub(super) fn try_borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { + // Shared (read) lock is always correct for in-place borrows: + // - this method is called internally while the Lua mutex is held, ensuring exclusive Lua-level + // access per call frame + // - with `send` feature, all owned userdata satisfies `T: Sync`, so simultaneous shared references + // from multiple threads are sound + // - without `send` feature, single-threaded execution makes shared lock safe for any `T` + let _guard = (self.raw_lock().try_lock_shared_guarded()).map_err(|_| Error::UserDataBorrowError)?; + Ok(f(unsafe { &*self.as_ptr() })) + } + + // Mutably borrows the wrapped value in-place. + #[inline(always)] + fn try_borrow_scoped_mut(&self, f: impl FnOnce(&mut T) -> R) -> Result { + let _guard = + (self.raw_lock().try_lock_exclusive_guarded()).map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(unsafe { &mut *self.as_ptr() })) + } + + // Immutably borrows the wrapped value and returns an owned reference. + #[inline(always)] + fn try_borrow_owned(&self) -> Result> { + UserDataRef::try_from(self.clone()) + } + + // Mutably borrows the wrapped value and returns an owned reference. + #[inline(always)] + fn try_borrow_owned_mut(&self) -> Result> { + UserDataRefMut::try_from(self.clone()) + } + + // Returns the wrapped value. + // + // This method checks that we have exclusive access to the value. + fn into_inner(self) -> Result { + if !self.raw_lock().try_lock_exclusive() { + return Err(Error::UserDataBorrowMutError); + } + Ok(match self { + Self::Default(inner) => XRc::into_inner(inner).unwrap().into_inner(), + #[cfg(feature = "serde")] + Self::Serializable(inner) => unsafe { + // The serde variant erases `T` to `Box`, so we + // must cast the raw pointer back to recover the concrete type. + let raw = Box::into_raw(XRc::into_inner(inner).unwrap().into_inner()); + *Box::from_raw(raw as *mut T) + }, + }) + } + + #[inline(always)] + fn strong_count(&self) -> usize { + match self { + Self::Default(inner) => XRc::strong_count(inner), + #[cfg(feature = "serde")] + Self::Serializable(inner) => XRc::strong_count(inner), + } + } + + #[inline(always)] + pub(super) fn raw_lock(&self) -> &RawLock { + match self { + Self::Default(inner) => unsafe { inner.raw() }, + #[cfg(feature = "serde")] + Self::Serializable(inner) => unsafe { inner.raw() }, + } + } + + #[inline(always)] + pub(super) fn as_ptr(&self) -> *mut T { + match self { + Self::Default(inner) => inner.data_ptr(), + #[cfg(feature = "serde")] + Self::Serializable(inner) => unsafe { (&mut **inner.data_ptr()) as *mut DynSerialize as *mut T }, + } + } +} + +#[cfg(feature = "serde")] +impl Serialize for UserDataStorage<()> { + fn serialize(&self, serializer: S) -> std::result::Result { + match self { + Self::Owned(variant @ UserDataVariant::Serializable(inner)) => unsafe { + let _guard = (variant.raw_lock().try_lock_shared_guarded()) + .map_err(|_| serde::ser::Error::custom(Error::UserDataBorrowError))?; + (*inner.data_ptr()).serialize(serializer) + }, + _ => Err(serde::ser::Error::custom("cannot serialize ")), + } + } +} + +pub(crate) enum ScopedUserDataVariant { + Ref(*const T), + RefMut(RefCell<*mut T>), + Boxed(RefCell<*mut T>), +} + +impl Drop for ScopedUserDataVariant { + #[inline] + fn drop(&mut self) { + if let Self::Boxed(value) = self + && let Ok(value) = value.try_borrow_mut() + { + unsafe { drop(Box::from_raw(*value)) } + } + } +} + +impl UserDataStorage { + #[inline(always)] + pub(crate) fn new(data: T) -> Self { + Self::Owned(UserDataVariant::Default(XRc::new(RwLock::new(data)))) + } + + #[inline(always)] + pub(crate) fn new_ref(data: &T) -> Self { + Self::Scoped(ScopedUserDataVariant::Ref(data)) + } + + #[inline(always)] + pub(crate) fn new_ref_mut(data: &mut T) -> Self { + Self::Scoped(ScopedUserDataVariant::RefMut(RefCell::new(data))) + } + + #[cfg(feature = "serde")] + #[inline(always)] + pub(crate) fn new_ser(data: T) -> Self + where + T: Serialize + crate::types::MaybeSend + crate::types::MaybeSync, + { + let data = Box::new(data) as Box; + let variant = UserDataVariant::Serializable(XRc::new(RwLock::new(data))); + Self::Owned(variant) + } + + #[cfg(feature = "serde")] + #[inline(always)] + pub(crate) fn is_serializable(&self) -> bool { + matches!(self, Self::Owned(UserDataVariant::Serializable(..))) + } + + // Immutably borrows the wrapped value and returns an owned reference. + #[inline(always)] + pub(crate) fn try_borrow_owned(&self) -> Result> { + match self { + Self::Owned(data) => data.try_borrow_owned(), + Self::Scoped(_) => Err(Error::UserDataTypeMismatch), + } + } + + // Mutably borrows the wrapped value and returns an owned reference. + #[inline(always)] + pub(crate) fn try_borrow_owned_mut(&self) -> Result> { + match self { + Self::Owned(data) => data.try_borrow_owned_mut(), + Self::Scoped(_) => Err(Error::UserDataTypeMismatch), + } + } + + #[inline(always)] + pub(crate) fn into_inner(self) -> Result { + match self { + Self::Owned(data) => data.into_inner(), + Self::Scoped(_) => Err(Error::UserDataTypeMismatch), + } + } +} + +impl UserDataStorage { + #[inline(always)] + pub(crate) fn new_scoped(data: T) -> Self { + let data = Box::into_raw(Box::new(data)); + Self::Scoped(ScopedUserDataVariant::Boxed(RefCell::new(data))) + } + + /// Returns `true` if it's safe to destroy the container. + /// + /// It's safe to destroy the container if the reference count is greater than 1 or the lock is + /// not acquired. + #[inline(always)] + pub(crate) fn is_safe_to_destroy(&self) -> bool { + match self { + Self::Owned(variant) => variant.strong_count() > 1 || !variant.raw_lock().is_locked(), + Self::Scoped(_) => false, + } + } + + /// Returns `true` if the container has exclusive access to the value. + #[inline(always)] + pub(crate) fn has_exclusive_access(&self) -> bool { + match self { + Self::Owned(variant) => !variant.raw_lock().is_locked(), + Self::Scoped(_) => false, + } + } + + #[inline] + pub(crate) fn try_borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { + match self { + Self::Owned(data) => data.try_borrow_scoped(f), + Self::Scoped(ScopedUserDataVariant::Ref(value)) => Ok(f(unsafe { &**value })), + Self::Scoped(ScopedUserDataVariant::RefMut(value) | ScopedUserDataVariant::Boxed(value)) => { + let t = value.try_borrow().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(unsafe { &**t })) + } + } + } + + #[inline] + pub(crate) fn try_borrow_scoped_mut(&self, f: impl FnOnce(&mut T) -> R) -> Result { + match self { + Self::Owned(data) => data.try_borrow_scoped_mut(f), + Self::Scoped(ScopedUserDataVariant::Ref(_)) => Err(Error::UserDataBorrowMutError), + Self::Scoped(ScopedUserDataVariant::RefMut(value) | ScopedUserDataVariant::Boxed(value)) => { + let mut t = value + .try_borrow_mut() + .map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(unsafe { &mut **t })) + } + } + } +} diff --git a/src/userdata/lock.rs b/src/userdata/lock.rs new file mode 100644 index 00000000..901a557b --- /dev/null +++ b/src/userdata/lock.rs @@ -0,0 +1,169 @@ +pub(crate) trait UserDataLock { + fn is_locked(&self) -> bool; + fn try_lock_shared(&self) -> bool; + fn try_lock_exclusive(&self) -> bool; + + unsafe fn unlock_shared(&self); + unsafe fn unlock_exclusive(&self); + + fn try_lock_shared_guarded(&self) -> Result, ()> { + if self.try_lock_shared() { + Ok(LockGuard { + lock: self, + exclusive: false, + }) + } else { + Err(()) + } + } + + fn try_lock_exclusive_guarded(&self) -> Result, ()> { + if self.try_lock_exclusive() { + Ok(LockGuard { + lock: self, + exclusive: true, + }) + } else { + Err(()) + } + } +} + +pub(crate) struct LockGuard<'a, L: UserDataLock + ?Sized> { + lock: &'a L, + exclusive: bool, +} + +impl Drop for LockGuard<'_, L> { + fn drop(&mut self) { + unsafe { + if self.exclusive { + self.lock.unlock_exclusive(); + } else { + self.lock.unlock_shared(); + } + } + } +} + +pub(crate) use lock_impl::{RawLock, RwLock}; + +#[cfg(not(feature = "send"))] +#[cfg(not(tarpaulin_include))] +mod lock_impl { + use std::cell::{Cell, UnsafeCell}; + + // Positive values represent the number of read references. + // Negative values represent the number of write references (only one allowed). + pub(crate) type RawLock = Cell; + + const UNUSED: isize = 0; + + impl super::UserDataLock for RawLock { + #[inline(always)] + fn is_locked(&self) -> bool { + self.get() != UNUSED + } + + #[inline(always)] + fn try_lock_shared(&self) -> bool { + let flag = self.get().checked_add(1).expect("userdata lock count overflow"); + if flag <= UNUSED { + return false; + } + self.set(flag); + true + } + + #[inline(always)] + fn try_lock_exclusive(&self) -> bool { + let flag = self.get(); + if flag != UNUSED { + return false; + } + self.set(UNUSED - 1); + true + } + + #[inline(always)] + unsafe fn unlock_shared(&self) { + let flag = self.get(); + debug_assert!(flag > UNUSED); + self.set(flag - 1); + } + + #[inline(always)] + unsafe fn unlock_exclusive(&self) { + let flag = self.get(); + debug_assert!(flag < UNUSED); + self.set(flag + 1); + } + } + + /// A cheap single-threaded read-write lock pairing a `parking_lot::RwLock` type. + pub(crate) struct RwLock { + lock: RawLock, + data: UnsafeCell, + } + + impl RwLock { + /// Creates a new `RwLock` containing the given value. + #[inline(always)] + pub(crate) fn new(value: T) -> Self { + RwLock { + lock: RawLock::new(UNUSED), + data: UnsafeCell::new(value), + } + } + + /// Returns a reference to the underlying raw lock. + #[inline(always)] + pub(crate) unsafe fn raw(&self) -> &RawLock { + &self.lock + } + + /// Returns a raw pointer to the underlying data. + #[inline(always)] + pub(crate) fn data_ptr(&self) -> *mut T { + self.data.get() + } + + /// Consumes this `RwLock`, returning the underlying data. + #[inline(always)] + pub(crate) fn into_inner(self) -> T { + self.data.into_inner() + } + } +} + +#[cfg(feature = "send")] +mod lock_impl { + pub(crate) use parking_lot::{RawRwLock as RawLock, RwLock}; + + impl super::UserDataLock for RawLock { + #[inline(always)] + fn is_locked(&self) -> bool { + parking_lot::lock_api::RawRwLock::is_locked(self) + } + + #[inline(always)] + fn try_lock_shared(&self) -> bool { + parking_lot::lock_api::RawRwLock::try_lock_shared(self) + } + + #[inline(always)] + fn try_lock_exclusive(&self) -> bool { + parking_lot::lock_api::RawRwLock::try_lock_exclusive(self) + } + + #[inline(always)] + unsafe fn unlock_shared(&self) { + parking_lot::lock_api::RawRwLock::unlock_shared(self) + } + + #[inline(always)] + unsafe fn unlock_exclusive(&self) { + parking_lot::lock_api::RawRwLock::unlock_exclusive(self) + } + } +} diff --git a/src/userdata/object.rs b/src/userdata/object.rs new file mode 100644 index 00000000..12019e01 --- /dev/null +++ b/src/userdata/object.rs @@ -0,0 +1,102 @@ +use crate::Function; +use crate::error::{Error, Result}; +use crate::state::WeakLua; +use crate::table::Table; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, ObjectLike}; +use crate::userdata::AnyUserData; +use crate::value::Value; + +#[cfg(feature = "async")] +use crate::function::AsyncCallFuture; + +impl ObjectLike for AnyUserData { + #[inline] + fn get(&self, key: impl IntoLua) -> Result { + // `lua_gettable` method used under the hood can work with any Lua value + // that has `__index` metamethod + Table(self.0.clone()).get_protected(key) + } + + #[inline] + fn set(&self, key: impl IntoLua, value: impl IntoLua) -> Result<()> { + // `lua_settable` method used under the hood can work with any Lua value + // that has `__newindex` metamethod + Table(self.0.clone()).set_protected(key, value) + } + + #[inline] + fn call(&self, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti, + { + Function(self.0.clone()).call(args) + } + + #[cfg(feature = "async")] + #[inline] + fn call_async(&self, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti, + { + Function(self.0.clone()).call_async(args) + } + + #[inline] + fn call_method(&self, name: &str, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti, + { + self.call_function(name, (self, args)) + } + + #[cfg(feature = "async")] + fn call_async_method(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti, + { + self.call_async_function(name, (self, args)) + } + + fn call_function(&self, name: &str, args: impl IntoLuaMulti) -> Result + where + R: FromLuaMulti, + { + match self.get(name)? { + Value::Function(func) => func.call(args), + val => { + let msg = format!("attempt to call a {} value (function '{name}')", val.type_name()); + Err(Error::RuntimeError(msg)) + } + } + } + + #[cfg(feature = "async")] + fn call_async_function(&self, name: &str, args: impl IntoLuaMulti) -> AsyncCallFuture + where + R: FromLuaMulti, + { + match self.get(name) { + Ok(Value::Function(func)) => func.call_async(args), + Ok(val) => { + let msg = format!("attempt to call a {} value (function '{name}')", val.type_name()); + AsyncCallFuture::error(Error::RuntimeError(msg)) + } + Err(err) => AsyncCallFuture::error(err), + } + } + + #[inline] + fn to_string(&self) -> Result { + Value::UserData(self.clone()).to_string() + } + + #[inline] + fn to_value(&self) -> Value { + Value::UserData(self.clone()) + } + + #[inline] + fn weak_lua(&self) -> &WeakLua { + &self.0.lua + } +} diff --git a/src/userdata/ref.rs b/src/userdata/ref.rs new file mode 100644 index 00000000..131b84d6 --- /dev/null +++ b/src/userdata/ref.rs @@ -0,0 +1,536 @@ +use std::any::{TypeId, type_name}; +use std::ops::{Deref, DerefMut}; +use std::os::raw::c_int; +use std::{fmt, mem}; + +use crate::error::{Error, Result}; +use crate::state::{Lua, RawLua}; +use crate::traits::FromLua; +use crate::userdata::AnyUserData; +use crate::util::{check_stack, get_userdata, take_userdata}; +use crate::value::Value; + +use super::cell::{UserDataStorage, UserDataVariant}; +use super::lock::{LockGuard, RawLock, UserDataLock}; + +#[cfg(feature = "userdata-wrappers")] +use { + parking_lot::{ + Mutex as MutexPL, MutexGuard as MutexGuardPL, RwLock as RwLockPL, + RwLockReadGuard as RwLockReadGuardPL, RwLockWriteGuard as RwLockWriteGuardPL, + }, + std::sync::Arc, +}; +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +use { + std::cell::{Ref, RefCell, RefMut}, + std::rc::Rc, +}; + +/// A wrapper type for a userdata value that provides read access. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRef { + // It's important to drop the guard first, as it refers to the `inner` data. + _guard: LockGuard<'static, RawLock>, + inner: UserDataRefInner, +} + +impl Deref for UserDataRef { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + &self.inner + } +} + +impl fmt::Debug for UserDataRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Display for UserDataRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl TryFrom> for UserDataRef { + type Error = Error; + + #[inline] + fn try_from(variant: UserDataVariant) -> Result { + // Shared (read) lock is always correct: + // - with `send` feature, `T: Sync` is guaranteed by the `MaybeSync` bound on userdata creation + // - without `send` feature, single-threaded access makes shared lock safe for any `T` + let guard = variant.raw_lock().try_lock_shared_guarded(); + let guard = guard.map_err(|_| Error::UserDataBorrowError)?; + let guard = unsafe { mem::transmute::, LockGuard<'static, _>>(guard) }; + Ok(UserDataRef::from_parts(UserDataRefInner::Default(variant), guard)) + } +} + +impl FromLua for UserDataRef { + fn from_lua(value: Value, _: &Lua) -> Result { + try_value_to_userdata::(value)?.borrow() + } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Self::borrow_from_stack(lua, lua.state(), idx) + } +} + +impl UserDataRef { + #[inline(always)] + fn from_parts(inner: UserDataRefInner, guard: LockGuard<'static, RawLock>) -> Self { + Self { _guard: guard, inner } + } + + #[cfg(feature = "userdata-wrappers")] + fn remap( + self, + f: impl FnOnce(UserDataVariant) -> Result>, + ) -> Result> { + match &self.inner { + UserDataRefInner::Default(variant) => { + let inner = f(variant.clone())?; + Ok(UserDataRef::from_parts(inner, self._guard)) + } + _ => Err(Error::UserDataTypeMismatch), + } + } + + pub(crate) unsafe fn borrow_from_stack( + lua: &RawLua, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result { + let type_id = lua.get_userdata_type_id::(state, idx)?; + match type_id { + Some(type_id) if type_id == TypeId::of::() => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_owned() + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>() => { + let ud = get_userdata::>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_rc()) + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_rc_refcell()) + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>() => { + let ud = get_userdata::>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc_mutex_pl()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc_rwlock_pl()) + } + _ => Err(Error::UserDataTypeMismatch), + } + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRef> { + fn transform_rc(self) -> Result> { + self.remap(|variant| Ok(UserDataRefInner::Rc(variant))) + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRef>> { + fn transform_rc_refcell(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let r#ref = obj.try_borrow().map_err(|_| Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, Ref<'static, T>>(r#ref); + Ok(UserDataRefInner::RcRefCell(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef> { + fn transform_arc(self) -> Result> { + self.remap(|variant| Ok(UserDataRefInner::Arc(variant))) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef>> { + fn transform_arc_mutex_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_lock().ok_or(Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, MutexGuardPL<'static, T>>(guard); + Ok(UserDataRefInner::ArcMutexPL(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef>> { + fn transform_arc_rwlock_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_read().ok_or(Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, RwLockReadGuardPL<'static, T>>(guard); + Ok(UserDataRefInner::ArcRwLockPL(borrow, variant)) + }) + } +} + +#[allow(unused)] +enum UserDataRefInner { + Default(UserDataVariant), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Rc(UserDataVariant>), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + RcRefCell(Ref<'static, T>, UserDataVariant>>), + + #[cfg(feature = "userdata-wrappers")] + Arc(UserDataVariant>), + #[cfg(feature = "userdata-wrappers")] + ArcMutexPL(MutexGuardPL<'static, T>, UserDataVariant>>), + #[cfg(feature = "userdata-wrappers")] + ArcRwLockPL(RwLockReadGuardPL<'static, T>, UserDataVariant>>), +} + +impl Deref for UserDataRefInner { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + match self { + Self::Default(inner) => unsafe { &*inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::Rc(inner) => unsafe { &*Rc::as_ptr(&*inner.as_ptr()) }, + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::Arc(inner) => unsafe { &*Arc::as_ptr(&*inner.as_ptr()) }, + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +/// A wrapper type for a userdata value that provides read and write access. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRefMut { + // It's important to drop the guard first, as it refers to the `inner` data. + _guard: LockGuard<'static, RawLock>, + inner: UserDataRefMutInner, +} + +impl Deref for UserDataRefMut { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for UserDataRefMut { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl fmt::Debug for UserDataRefMut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Display for UserDataRefMut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl TryFrom> for UserDataRefMut { + type Error = Error; + + #[inline] + fn try_from(variant: UserDataVariant) -> Result { + let guard = variant.raw_lock().try_lock_exclusive_guarded(); + let guard = guard.map_err(|_| Error::UserDataBorrowMutError)?; + let guard = unsafe { mem::transmute::, LockGuard<'static, _>>(guard) }; + Ok(UserDataRefMut::from_parts( + UserDataRefMutInner::Default(variant), + guard, + )) + } +} + +impl FromLua for UserDataRefMut { + fn from_lua(value: Value, _: &Lua) -> Result { + try_value_to_userdata::(value)?.borrow_mut() + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Self::borrow_from_stack(lua, lua.state(), idx) + } +} + +impl UserDataRefMut { + #[inline(always)] + fn from_parts(inner: UserDataRefMutInner, guard: LockGuard<'static, RawLock>) -> Self { + Self { _guard: guard, inner } + } + + #[cfg(feature = "userdata-wrappers")] + fn remap( + self, + f: impl FnOnce(UserDataVariant) -> Result>, + ) -> Result> { + match &self.inner { + UserDataRefMutInner::Default(variant) => { + let inner = f(variant.clone())?; + Ok(UserDataRefMut::from_parts(inner, self._guard)) + } + _ => Err(Error::UserDataTypeMismatch), + } + } + + pub(crate) unsafe fn borrow_from_stack( + lua: &RawLua, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result { + let type_id = lua.get_userdata_type_id::(state, idx)?; + match type_id { + Some(type_id) if type_id == TypeId::of::() => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_owned_mut() + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>() => Err(Error::UserDataBorrowMutError), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_rc_refcell()) + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>() => Err(Error::UserDataBorrowMutError), + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_arc_mutex_pl()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_arc_rwlock_pl()) + } + _ => Err(Error::UserDataTypeMismatch), + } + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRefMut>> { + fn transform_rc_refcell(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let refmut = obj.try_borrow_mut().map_err(|_| Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, RefMut<'static, T>>(refmut); + Ok(UserDataRefMutInner::RcRefCell(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRefMut>> { + fn transform_arc_mutex_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_lock().ok_or(Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, MutexGuardPL<'static, T>>(guard); + Ok(UserDataRefMutInner::ArcMutexPL(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRefMut>> { + fn transform_arc_rwlock_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_write().ok_or(Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, RwLockWriteGuardPL<'static, T>>(guard); + Ok(UserDataRefMutInner::ArcRwLockPL(borrow, variant)) + }) + } +} + +#[allow(unused)] +enum UserDataRefMutInner { + Default(UserDataVariant), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + RcRefCell(RefMut<'static, T>, UserDataVariant>>), + + #[cfg(feature = "userdata-wrappers")] + ArcMutexPL(MutexGuardPL<'static, T>, UserDataVariant>>), + #[cfg(feature = "userdata-wrappers")] + ArcRwLockPL(RwLockWriteGuardPL<'static, T>, UserDataVariant>>), +} + +impl Deref for UserDataRefMutInner { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + match self { + Self::Default(inner) => unsafe { &*inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +impl DerefMut for UserDataRefMutInner { + #[inline] + fn deref_mut(&mut self) -> &mut T { + match self { + Self::Default(inner) => unsafe { &mut *inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +/// A wrapper type that takes ownership of a userdata value. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua by taking +/// ownership of it. +/// The original Lua userdata is marked as destructed and cannot be used further. +pub struct UserDataOwned(pub T); + +impl Deref for UserDataOwned { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + &self.0 + } +} + +impl DerefMut for UserDataOwned { + #[inline] + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl fmt::Debug for UserDataOwned { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Display for UserDataOwned { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl FromLua for UserDataOwned { + fn from_lua(value: Value, _: &Lua) -> Result { + try_value_to_userdata::(value)?.take().map(UserDataOwned) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let state = lua.state(); + let type_id = lua.get_userdata_type_id::(state, idx)?; + match type_id { + Some(type_id) if type_id == TypeId::of::() => { + let ud = get_userdata::>(state, idx); + if (*ud).has_exclusive_access() { + check_stack(state, 1)?; + take_userdata::>(state, idx) + .into_inner() + .map(UserDataOwned) + } else { + Err(Error::UserDataBorrowMutError) + } + } + _ => Err(Error::UserDataTypeMismatch), + } + } +} + +#[inline] +fn try_value_to_userdata(value: Value) -> Result { + match value { + Value::UserData(ud) => Ok(ud), + _ => Err(Error::from_lua_conversion( + value.type_name(), + "userdata", + format!("expected userdata of type {}", type_name::()), + )), + } +} + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(UserDataRef<()>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_not_impl_all!(UserDataRef>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(UserDataRefMut<()>: Sync, Send); + #[cfg(feature = "send")] + static_assertions::assert_not_impl_all!(UserDataRefMut>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(UserDataOwned<()>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_not_impl_all!(UserDataOwned>: Send, Sync); + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_all!(UserDataRef<()>: Send, Sync); + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_all!(UserDataRefMut<()>: Send, Sync); +} diff --git a/src/userdata/registry.rs b/src/userdata/registry.rs new file mode 100644 index 00000000..c5138140 --- /dev/null +++ b/src/userdata/registry.rs @@ -0,0 +1,684 @@ +#![allow(clippy::await_holding_refcell_ref, clippy::await_holding_lock)] + +use std::any::TypeId; +use std::cell::RefCell; +use std::marker::PhantomData; +use std::os::raw::c_void; + +use crate::error::{Error, Result}; +use crate::state::{Lua, LuaGuard}; +use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; +use crate::types::{Callback, MaybeSend}; +use crate::userdata::{ + AnyUserData, MetaMethod, TypeIdHints, UserData, UserDataFields, UserDataMethods, UserDataStorage, + borrow_userdata_scoped, borrow_userdata_scoped_mut, +}; +use crate::util::short_type_name; +use crate::value::Value; + +#[cfg(feature = "async")] +use { + crate::types::AsyncCallback, + crate::userdata::{UserDataRef, UserDataRefMut}, + std::future::{self, Future}, +}; + +#[derive(Clone, Copy)] +enum UserDataType { + Shared(TypeIdHints), + Unique(*mut c_void), +} + +/// Handle to registry for userdata methods and metamethods. +pub struct UserDataRegistry { + lua: LuaGuard, + raw: RawUserDataRegistry, + r#type: UserDataType, + _phantom: PhantomData, +} + +pub(crate) struct RawUserDataRegistry { + // Fields + pub(crate) fields: Vec<(String, Result)>, + pub(crate) field_getters: Vec<(String, Callback)>, + pub(crate) field_setters: Vec<(String, Callback)>, + pub(crate) meta_fields: Vec<(String, Result)>, + + // Methods + pub(crate) methods: Vec<(String, Callback)>, + #[cfg(feature = "async")] + pub(crate) async_methods: Vec<(String, AsyncCallback)>, + pub(crate) meta_methods: Vec<(String, Callback)>, + #[cfg(feature = "async")] + pub(crate) async_meta_methods: Vec<(String, AsyncCallback)>, + + pub(crate) destructor: ffi::lua_CFunction, + pub(crate) type_id: Option, + pub(crate) type_name: String, + + #[cfg(feature = "luau")] + pub(crate) enable_namecall: bool, +} + +impl UserDataType { + #[inline] + pub(crate) fn type_id(&self) -> Option { + match self { + UserDataType::Shared(hints) => Some(hints.type_id()), + UserDataType::Unique(_) => None, + } + } +} + +#[cfg(feature = "send")] +unsafe impl Send for UserDataType {} + +impl UserDataRegistry { + #[inline(always)] + pub(crate) fn new(lua: &Lua) -> Self { + Self::with_type(lua, UserDataType::Shared(TypeIdHints::new::())) + } +} + +impl UserDataRegistry { + #[inline(always)] + pub(crate) fn new_unique(lua: &Lua, ud_ptr: *mut c_void) -> Self { + Self::with_type(lua, UserDataType::Unique(ud_ptr)) + } + + #[inline(always)] + fn with_type(lua: &Lua, r#type: UserDataType) -> Self { + let raw = RawUserDataRegistry { + fields: Vec::new(), + field_getters: Vec::new(), + field_setters: Vec::new(), + meta_fields: Vec::new(), + methods: Vec::new(), + #[cfg(feature = "async")] + async_methods: Vec::new(), + meta_methods: Vec::new(), + #[cfg(feature = "async")] + async_meta_methods: Vec::new(), + destructor: super::util::destroy_userdata_storage::, + type_id: r#type.type_id(), + type_name: short_type_name::(), + #[cfg(feature = "luau")] + enable_namecall: false, + }; + + UserDataRegistry { + lua: lua.lock_arc(), + raw, + r#type, + _phantom: PhantomData, + } + } + + /// Enables support for the namecall optimization in Luau. + /// + /// This enables methods resolution optimization in Luau for complex userdata types with methods + /// and field getters. When enabled, Luau will use a faster lookup path for method calls when a + /// specific syntax is used (e.g. `obj:method()`. + /// + /// This optimization does not play well with async methods, custom `__index` metamethod and + /// field getters as functions. So, it is disabled by default. + /// + /// Use with caution. + #[doc(hidden)] + #[cfg(feature = "luau")] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn enable_namecall(&mut self) { + self.raw.enable_namecall = true; + } + + fn box_method(&self, name: &str, method: M) -> Callback + where + M: Fn(&Lua, &T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + macro_rules! try_self_arg { + ($res:expr) => { + $res.map_err(|err| Error::bad_self_argument(&name, err))? + }; + } + + let target_type = self.r#type; + Box::new(move |rawlua, nargs| unsafe { + if nargs == 0 { + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); + } + let state = rawlua.state(); + // Find absolute "self" index before processing args + let self_index = ffi::lua_absindex(state, -nargs); + // Self was at position 1, so we pass 2 here + let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); + + match target_type { + #[rustfmt::skip] + UserDataType::Shared(type_hints) => { + let type_id = try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + try_self_arg!(borrow_userdata_scoped(state, self_index, type_id, type_hints, |ud| { + method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) + })) + } + UserDataType::Unique(target_ptr) if ffi::lua_touserdata(state, self_index) == target_ptr => { + let ud = target_ptr as *mut UserDataStorage; + try_self_arg!((*ud).try_borrow_scoped(|ud| { + method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) + })) + } + UserDataType::Unique(_) => { + try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)) + } + } + }) + } + + fn box_method_mut(&self, name: &str, method: M) -> Callback + where + M: FnMut(&Lua, &mut T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + macro_rules! try_self_arg { + ($res:expr) => { + $res.map_err(|err| Error::bad_self_argument(&name, err))? + }; + } + + let method = RefCell::new(method); + let target_type = self.r#type; + Box::new(move |rawlua, nargs| unsafe { + let mut method = method.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; + if nargs == 0 { + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); + } + let state = rawlua.state(); + // Find absolute "self" index before processing args + let self_index = ffi::lua_absindex(state, -nargs); + // Self was at position 1, so we pass 2 here + let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); + + match target_type { + #[rustfmt::skip] + UserDataType::Shared(type_hints) => { + let type_id = try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + try_self_arg!(borrow_userdata_scoped_mut(state, self_index, type_id, type_hints, |ud| { + method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) + })) + } + UserDataType::Unique(target_ptr) if ffi::lua_touserdata(state, self_index) == target_ptr => { + let ud = target_ptr as *mut UserDataStorage; + try_self_arg!((*ud).try_borrow_scoped_mut(|ud| { + method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) + })) + } + UserDataType::Unique(_) => { + try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)) + } + } + }) + } + + #[cfg(feature = "async")] + fn box_async_method(&self, name: &str, method: M) -> AsyncCallback + where + T: 'static, + M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + macro_rules! try_self_arg { + ($res:expr) => { + match $res { + Ok(res) => res, + Err(err) => return Box::pin(future::ready(Err(Error::bad_self_argument(&name, err)))), + } + }; + } + + Box::new(move |rawlua, nargs| unsafe { + if nargs == 0 { + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); + } + // Stack will be empty when polling the future, keep `self` on the ref thread + let self_ud = try_self_arg!(AnyUserData::from_stack(-nargs, rawlua)); + let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); + + let self_ud = try_self_arg!(self_ud.borrow()); + let args = match args { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = method(lua.clone(), self_ud, args); + // Lua is locked when the future is polled + Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) }) + }) + } + + #[cfg(feature = "async")] + fn box_async_method_mut(&self, name: &str, method: M) -> AsyncCallback + where + T: 'static, + M: Fn(Lua, UserDataRefMut, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + macro_rules! try_self_arg { + ($res:expr) => { + match $res { + Ok(res) => res, + Err(err) => return Box::pin(future::ready(Err(Error::bad_self_argument(&name, err)))), + } + }; + } + + Box::new(move |rawlua, nargs| unsafe { + if nargs == 0 { + let err = Error::from_lua_conversion("missing argument", "userdata", None); + try_self_arg!(Err(err)); + } + // Stack will be empty when polling the future, keep `self` on the ref thread + let self_ud = try_self_arg!(AnyUserData::from_stack(-nargs, rawlua)); + let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); + + let self_ud = try_self_arg!(self_ud.borrow_mut()); + let args = match args { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = method(lua.clone(), self_ud, args); + // Lua is locked when the future is polled + Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) }) + }) + } + + fn box_function(&self, name: &str, function: F) -> Callback + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + Box::new(move |lua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, Some(&name), lua)?; + function(lua.lua(), args)?.push_into_stack_multi(lua) + }) + } + + fn box_function_mut(&self, name: &str, function: F) -> Callback + where + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + let function = RefCell::new(function); + Box::new(move |lua, nargs| unsafe { + let function = &mut *function + .try_borrow_mut() + .map_err(|_| Error::RecursiveMutCallback)?; + let args = A::from_stack_args(nargs, 1, Some(&name), lua)?; + function(lua.lua(), args)?.push_into_stack_multi(lua) + }) + } + + #[cfg(feature = "async")] + fn box_async_function(&self, name: &str, function: F) -> AsyncCallback + where + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = get_function_name::(name); + Box::new(move |rawlua, nargs| unsafe { + let args = match A::from_stack_args(nargs, 1, Some(&name), rawlua) { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = function(lua.clone(), args); + Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) }) + }) + } + + pub(crate) fn check_meta_field(lua: &Lua, name: &str, value: impl IntoLua) -> Result { + let value = value.into_lua(lua)?; + if name == MetaMethod::Index || name == MetaMethod::NewIndex { + match value { + Value::Nil | Value::Table(_) | Value::Function(_) => {} + _ => { + return Err(Error::MetaMethodTypeError { + method: name.to_string(), + type_name: value.type_name(), + message: Some("expected nil, table or function".to_string()), + }); + } + } + } + value.into_lua(lua) + } + + #[inline(always)] + pub(crate) fn into_raw(self) -> RawUserDataRegistry { + self.raw + } +} + +// Returns function name for the type `T`, without the module path +fn get_function_name(name: &str) -> String { + format!("{}.{name}", short_type_name::()) +} + +impl UserDataFields for UserDataRegistry { + fn add_field(&mut self, name: impl Into, value: V) + where + V: IntoLua + 'static, + { + let name = name.into(); + self.raw.fields.push((name, value.into_lua(self.lua.lua()))); + } + + fn add_field_method_get(&mut self, name: impl Into, method: M) + where + M: Fn(&Lua, &T) -> Result + MaybeSend + 'static, + R: IntoLua, + { + let name = name.into(); + let callback = self.box_method(&name, move |lua, data, ()| method(lua, data)); + self.raw.field_getters.push((name, callback)); + } + + fn add_field_method_set(&mut self, name: impl Into, method: M) + where + M: FnMut(&Lua, &mut T, A) -> Result<()> + MaybeSend + 'static, + A: FromLua, + { + let name = name.into(); + let callback = self.box_method_mut(&name, method); + self.raw.field_setters.push((name, callback)); + } + + fn add_field_function_get(&mut self, name: impl Into, function: F) + where + F: Fn(&Lua, AnyUserData) -> Result + MaybeSend + 'static, + R: IntoLua, + { + let name = name.into(); + let callback = self.box_function(&name, function); + self.raw.field_getters.push((name, callback)); + } + + fn add_field_function_set(&mut self, name: impl Into, mut function: F) + where + F: FnMut(&Lua, AnyUserData, A) -> Result<()> + MaybeSend + 'static, + A: FromLua, + { + let name = name.into(); + let callback = self.box_function_mut(&name, move |lua, (data, val)| function(lua, data, val)); + self.raw.field_setters.push((name, callback)); + } + + fn add_meta_field(&mut self, name: impl Into, value: V) + where + V: IntoLua + 'static, + { + let lua = self.lua.lua(); + let name = name.into(); + let field = Self::check_meta_field(lua, &name, value).and_then(|v| v.into_lua(lua)); + self.raw.meta_fields.push((name, field)); + } + + fn add_meta_field_with(&mut self, name: impl Into, f: F) + where + F: FnOnce(&Lua) -> Result + 'static, + R: IntoLua, + { + let lua = self.lua.lua(); + let name = name.into(); + let field = f(lua).and_then(|v| Self::check_meta_field(lua, &name, v).and_then(|v| v.into_lua(lua))); + self.raw.meta_fields.push((name, field)); + } +} + +impl UserDataMethods for UserDataRegistry { + fn add_method(&mut self, name: impl Into, method: M) + where + M: Fn(&Lua, &T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_method(&name, method); + self.raw.methods.push((name, callback)); + } + + fn add_method_mut(&mut self, name: impl Into, method: M) + where + M: FnMut(&Lua, &mut T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_method_mut(&name, method); + self.raw.methods.push((name, callback)); + } + + #[cfg(feature = "async")] + fn add_async_method(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_method(&name, method); + self.raw.async_methods.push((name, callback)); + } + + #[cfg(feature = "async")] + fn add_async_method_mut(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRefMut, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_method_mut(&name, method); + self.raw.async_methods.push((name, callback)); + } + + fn add_function(&mut self, name: impl Into, function: F) + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_function(&name, function); + self.raw.methods.push((name, callback)); + } + + fn add_function_mut(&mut self, name: impl Into, function: F) + where + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_function_mut(&name, function); + self.raw.methods.push((name, callback)); + } + + #[cfg(feature = "async")] + fn add_async_function(&mut self, name: impl Into, function: F) + where + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_function(&name, function); + self.raw.async_methods.push((name, callback)); + } + + fn add_meta_method(&mut self, name: impl Into, method: M) + where + M: Fn(&Lua, &T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_method(&name, method); + self.raw.meta_methods.push((name, callback)); + } + + fn add_meta_method_mut(&mut self, name: impl Into, method: M) + where + M: FnMut(&Lua, &mut T, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_method_mut(&name, method); + self.raw.meta_methods.push((name, callback)); + } + + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + fn add_async_meta_method(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_method(&name, method); + self.raw.async_meta_methods.push((name, callback)); + } + + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + fn add_async_meta_method_mut(&mut self, name: impl Into, method: M) + where + T: 'static, + M: Fn(Lua, UserDataRefMut, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti, + MR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_method_mut(&name, method); + self.raw.async_meta_methods.push((name, callback)); + } + + fn add_meta_function(&mut self, name: impl Into, function: F) + where + F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_function(&name, function); + self.raw.meta_methods.push((name, callback)); + } + + fn add_meta_function_mut(&mut self, name: impl Into, function: F) + where + F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, + A: FromLuaMulti, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_function_mut(&name, function); + self.raw.meta_methods.push((name, callback)); + } + + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + fn add_async_meta_function(&mut self, name: impl Into, function: F) + where + F: Fn(Lua, A) -> FR + MaybeSend + 'static, + A: FromLuaMulti, + FR: Future> + MaybeSend + 'static, + R: IntoLuaMulti, + { + let name = name.into(); + let callback = self.box_async_function(&name, function); + self.raw.async_meta_methods.push((name, callback)); + } +} + +macro_rules! lua_userdata_impl { + ($type:ty) => { + impl UserData for $type { + fn register(registry: &mut UserDataRegistry) { + let mut orig_registry = UserDataRegistry::new(registry.lua.lua()); + T::register(&mut orig_registry); + + // Copy all fields, methods, etc. from the original registry + (registry.raw.fields).extend(orig_registry.raw.fields); + (registry.raw.field_getters).extend(orig_registry.raw.field_getters); + (registry.raw.field_setters).extend(orig_registry.raw.field_setters); + (registry.raw.meta_fields).extend(orig_registry.raw.meta_fields); + (registry.raw.methods).extend(orig_registry.raw.methods); + #[cfg(feature = "async")] + (registry.raw.async_methods).extend(orig_registry.raw.async_methods); + (registry.raw.meta_methods).extend(orig_registry.raw.meta_methods); + #[cfg(feature = "async")] + (registry.raw.async_meta_methods).extend(orig_registry.raw.async_meta_methods); + } + } + }; +} + +// A special proxy object for UserData +pub(crate) struct UserDataProxy(pub(crate) PhantomData); + +// `UserDataProxy` holds no real `T` value, only a type marker, so it is always safe to send/share. +#[cfg(feature = "send")] +unsafe impl Send for UserDataProxy {} +#[cfg(feature = "send")] +unsafe impl Sync for UserDataProxy {} + +lua_userdata_impl!(UserDataProxy); + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +lua_userdata_impl!(std::rc::Rc); +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +lua_userdata_impl!(std::rc::Rc>); +#[cfg(feature = "userdata-wrappers")] +lua_userdata_impl!(std::sync::Arc); +#[cfg(feature = "userdata-wrappers")] +lua_userdata_impl!(std::sync::Arc>); +#[cfg(feature = "userdata-wrappers")] +lua_userdata_impl!(std::sync::Arc>); +#[cfg(feature = "userdata-wrappers")] +lua_userdata_impl!(std::sync::Arc>); +#[cfg(feature = "userdata-wrappers")] +lua_userdata_impl!(std::sync::Arc>); + +#[cfg(test)] +mod assertions { + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(super::RawUserDataRegistry: Send); +} diff --git a/src/userdata/util.rs b/src/userdata/util.rs new file mode 100644 index 00000000..6c5f0f8f --- /dev/null +++ b/src/userdata/util.rs @@ -0,0 +1,479 @@ +use std::any::TypeId; +use std::os::raw::c_int; +use std::ptr; + +use rustc_hash::FxHashMap; + +use super::UserDataStorage; +use crate::error::{Error, Result}; +use crate::types::CallbackPtr; +use crate::util::{get_userdata, rawget_field, rawset_field, take_userdata}; + +// Userdata type hints, used to match types of wrapped userdata +#[derive(Clone, Copy)] +pub(crate) struct TypeIdHints { + t: TypeId, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc: TypeId, + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc_refcell: TypeId, + + #[cfg(feature = "userdata-wrappers")] + arc: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_mutex: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_rwlock: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_pl_mutex: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_pl_rwlock: TypeId, +} + +impl TypeIdHints { + pub(crate) fn new() -> Self { + Self { + t: TypeId::of::(), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc: TypeId::of::>(), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc_refcell: TypeId::of::>>(), + + #[cfg(feature = "userdata-wrappers")] + arc: TypeId::of::>(), + #[cfg(feature = "userdata-wrappers")] + arc_mutex: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_rwlock: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_pl_mutex: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_pl_rwlock: TypeId::of::>>(), + } + } + + #[inline(always)] + pub(crate) fn type_id(&self) -> TypeId { + self.t + } +} + +pub(crate) unsafe fn borrow_userdata_scoped( + state: *mut ffi::lua_State, + idx: c_int, + type_id: Option, + type_hints: TypeIdHints, + f: impl FnOnce(&T) -> R, +) -> Result { + match type_id { + Some(type_id) if type_id == type_hints.t => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc_refcell => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_lock().ok_or(Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_read().ok_or(Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + _ => Err(Error::UserDataTypeMismatch), + } +} + +pub(crate) unsafe fn borrow_userdata_scoped_mut( + state: *mut ffi::lua_State, + idx: c_int, + type_id: Option, + type_hints: TypeIdHints, + f: impl FnOnce(&mut T) -> R, +) -> Result { + match type_id { + Some(type_id) if type_id == type_hints.t => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| f(ud)) + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| match std::rc::Rc::get_mut(ud) { + Some(ud) => Ok(f(ud)), + None => Err(Error::UserDataBorrowMutError), + })? + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc_refcell => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let mut ud = ud.try_borrow_mut().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| match std::sync::Arc::get_mut(ud) { + Some(ud) => Ok(f(ud)), + None => Err(Error::UserDataBorrowMutError), + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_lock().ok_or(Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_write().ok_or(Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + _ => Err(Error::UserDataTypeMismatch), + } +} + +// Populates the given table with the appropriate members to be a userdata metatable for the given +// type. This function takes the given table at the `metatable` index, and adds an appropriate +// `__gc` member to it for the given type and a `__metatable` entry to protect the table from script +// access. The function also, if given a `field_getters` or `methods` tables, will create an +// `__index` metamethod (capturing previous one) to lookup in `field_getters` first, then `methods` +// and falling back to the captured `__index` if no matches found. +// The same is also applicable for `__newindex` metamethod and `field_setters` table. +// Internally uses 9 stack spaces and does not call checkstack. +pub(crate) unsafe fn init_userdata_metatable( + state: *mut ffi::lua_State, + metatable: c_int, + field_getters: Option, + field_setters: Option, + methods: Option, + _methods_map: Option, CallbackPtr>>, // Used only in Luau for `__namecall` +) -> Result<()> { + if field_getters.is_some() || methods.is_some() { + // Push `__index` generator function + init_userdata_metatable_index(state)?; + + let index_type = rawget_field(state, metatable, "__index")?; + match index_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + for &idx in &[field_getters, methods] { + if let Some(idx) = idx { + ffi::lua_pushvalue(state, idx); + } else { + ffi::lua_pushnil(state); + } + } + + // Generate `__index` + protect_lua!(state, 4, 1, fn(state) ffi::lua_call(state, 3, 1))?; + } + _ => mlua_panic!("improper `__index` type: {}", index_type), + } + + rawset_field(state, metatable, "__index")?; + + #[cfg(feature = "luau")] + if let Some(methods_map) = _methods_map { + // In Luau we can speedup method calls by providing a dedicated `__namecall` metamethod + push_userdata_metatable_namecall(state, methods_map)?; + rawset_field(state, metatable, "__namecall")?; + } + } + + if let Some(field_setters) = field_setters { + // Push `__newindex` generator function + init_userdata_metatable_newindex(state)?; + + let newindex_type = rawget_field(state, metatable, "__newindex")?; + match newindex_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + ffi::lua_pushvalue(state, field_setters); + // Generate `__newindex` + protect_lua!(state, 3, 1, fn(state) ffi::lua_call(state, 2, 1))?; + } + _ => mlua_panic!("improper `__newindex` type: {}", newindex_type), + } + + rawset_field(state, metatable, "__newindex")?; + } + + ffi::lua_pushboolean(state, 0); + rawset_field(state, metatable, "__metatable")?; + + Ok(()) +} + +unsafe extern "C-unwind" fn lua_error_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_error(state); +} + +unsafe extern "C-unwind" fn lua_isfunction_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushboolean(state, ffi::lua_isfunction(state, -1)); + 1 +} + +unsafe extern "C-unwind" fn lua_istable_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushboolean(state, ffi::lua_istable(state, -1)); + 1 +} + +unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> { + let index_key = &USERDATA_METATABLE_INDEX as *const u8 as *const _; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, index_key) == ffi::LUA_TFUNCTION { + return Ok(()); + } + ffi::lua_pop(state, 1); + + // Create and cache `__index` generator + let code = cr#" + local error, isfunction, istable = ... + return function (__index, field_getters, methods) + -- Common case: has field getters and index is a table + if field_getters ~= nil and methods == nil and istable(__index) then + return function (self, key) + local field_getter = field_getters[key] + if field_getter ~= nil then + return field_getter(self) + end + return __index[key] + end + end + + return function (self, key) + if field_getters ~= nil then + local field_getter = field_getters[key] + if field_getter ~= nil then + return field_getter(self) + end + end + + if methods ~= nil then + local method = methods[key] + if method ~= nil then + return method + end + end + + if isfunction(__index) then + return __index(self, key) + elseif __index == nil then + error("attempt to get an unknown field '"..key.."'") + else + return __index[key] + end + end + end + "#; + protect_lua!(state, 0, 1, |state| { + let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code.count_bytes(), cstr!("=__mlua_index")); + if ret != ffi::LUA_OK { + ffi::lua_error(state); + } + ffi::lua_pushcfunction(state, lua_error_impl); + ffi::lua_pushcfunction(state, lua_isfunction_impl); + ffi::lua_pushcfunction(state, lua_istable_impl); + ffi::lua_call(state, 3, 1); + + #[cfg(feature = "luau-jit")] + if ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_compile(state, -1); + } + + // Store in the registry + ffi::lua_pushvalue(state, -1); + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, index_key); + }) +} + +unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> { + let newindex_key = &USERDATA_METATABLE_NEWINDEX as *const u8 as *const _; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, newindex_key) == ffi::LUA_TFUNCTION { + return Ok(()); + } + ffi::lua_pop(state, 1); + + // Create and cache `__newindex` generator + let code = cr#" + local error, isfunction = ... + return function (__newindex, field_setters) + return function (self, key, value) + if field_setters ~= nil then + local field_setter = field_setters[key] + if field_setter ~= nil then + field_setter(self, value) + return + end + end + + if isfunction(__newindex) then + __newindex(self, key, value) + elseif __newindex == nil then + error("attempt to set an unknown field '"..key.."'") + else + __newindex[key] = value + end + end + end + "#; + protect_lua!(state, 0, 1, |state| { + let code_len = code.count_bytes(); + let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code_len, cstr!("=__mlua_newindex")); + if ret != ffi::LUA_OK { + ffi::lua_error(state); + } + ffi::lua_pushcfunction(state, lua_error_impl); + ffi::lua_pushcfunction(state, lua_isfunction_impl); + ffi::lua_call(state, 2, 1); + + #[cfg(feature = "luau-jit")] + if ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_compile(state, -1); + } + + // Store in the registry + ffi::lua_pushvalue(state, -1); + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, newindex_key); + }) +} + +#[cfg(feature = "luau")] +unsafe fn push_userdata_metatable_namecall( + state: *mut ffi::lua_State, + methods_map: FxHashMap, CallbackPtr>, +) -> Result<()> { + unsafe extern "C-unwind" fn namecall(state: *mut ffi::lua_State) -> c_int { + let name = ffi::lua_namecallatom(state, ptr::null_mut()); + if name.is_null() { + ffi::luaL_error(state, cstr!("attempt to call an unknown method")); + } + let name_cs = std::ffi::CStr::from_ptr(name); + let methods_map = get_userdata::, CallbackPtr>>(state, ffi::lua_upvalueindex(1)); + let callback_ptr = match (*methods_map).get(name_cs.to_bytes()) { + Some(ptr) => *ptr, + #[rustfmt::skip] + None => ffi::luaL_error(state, cstr!("attempt to call an unknown method '%s'"), name), + }; + crate::state::callback_error_ext(state, ptr::null_mut(), true, |extra, nargs| { + let rawlua = (*extra).raw_lua(); + (*callback_ptr)(rawlua, nargs) + }) + } + + // Automatic destructor is provided for any Luau userdata + crate::util::push_userdata(state, methods_map, true)?; + protect_lua!(state, 1, 1, |state| { + ffi::lua_pushcclosured(state, namecall, cstr!("__namecall"), 1); + }) +} + +// This method is called by Lua GC when it's time to collect the userdata. +// +// This method is usually used to collect internal userdata. +#[cfg(not(feature = "luau"))] +pub(crate) unsafe extern "C-unwind" fn collect_userdata(state: *mut ffi::lua_State) -> c_int { + let ud = get_userdata::(state, -1); + ptr::drop_in_place(ud); + 0 +} + +// This method is called by Luau GC when it's time to collect the userdata. +#[cfg(feature = "luau")] +pub(crate) unsafe extern "C" fn collect_userdata( + state: *mut ffi::lua_State, + ud: *mut std::os::raw::c_void, +) { + // Almost none Lua operations are allowed when destructor is running, + // so we need to set a flag to prevent calling any Lua functions + let extra = (*ffi::lua_callbacks(state)).userdata as *mut crate::state::ExtraData; + (*extra).running_gc = true; + // Luau does not support _any_ panics in destructors (they are declared as "C", NOT as "C-unwind"), + // so any panics will trigger `abort()`. + ptr::drop_in_place(ud as *mut T); + (*extra).running_gc = false; +} + +// This method can be called by user or Lua GC to destroy the userdata. +// It checks if the userdata is safe to destroy and sets the "destroyed" metatable +// to prevent further GC collection. +pub(super) unsafe extern "C-unwind" fn destroy_userdata_storage(state: *mut ffi::lua_State) -> c_int { + let ud = get_userdata::>(state, 1); + if (*ud).is_safe_to_destroy() { + take_userdata::>(state, 1); + ffi::lua_pushboolean(state, 1); + } else { + ffi::lua_pushboolean(state, 0); + } + 1 +} + +static USERDATA_METATABLE_INDEX: u8 = 0; +static USERDATA_METATABLE_NEWINDEX: u8 = 0; diff --git a/src/userdata_impl.rs b/src/userdata_impl.rs deleted file mode 100644 index 913bd732..00000000 --- a/src/userdata_impl.rs +++ /dev/null @@ -1,624 +0,0 @@ -use std::any::TypeId; -use std::cell::{Ref, RefCell, RefMut}; -use std::marker::PhantomData; -use std::sync::{Arc, Mutex, RwLock}; - -use crate::error::{Error, Result}; -use crate::ffi; -use crate::lua::Lua; -use crate::types::{Callback, MaybeSend}; -use crate::userdata::{ - AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods, -}; -use crate::util::{check_stack, get_userdata, StackGuard}; -use crate::value::{FromLua, FromLuaMulti, ToLua, ToLuaMulti, Value}; - -#[cfg(not(feature = "send"))] -use std::rc::Rc; - -#[cfg(feature = "async")] -use { - crate::types::AsyncCallback, - futures_core::future::Future, - futures_util::future::{self, TryFutureExt}, -}; - -pub(crate) struct StaticUserDataMethods<'lua, T: 'static + UserData> { - pub(crate) methods: Vec<(Vec, Callback<'lua, 'static>)>, - #[cfg(feature = "async")] - pub(crate) async_methods: Vec<(Vec, AsyncCallback<'lua, 'static>)>, - pub(crate) meta_methods: Vec<(MetaMethod, Callback<'lua, 'static>)>, - #[cfg(feature = "async")] - pub(crate) async_meta_methods: Vec<(MetaMethod, AsyncCallback<'lua, 'static>)>, - _type: PhantomData, -} - -impl<'lua, T: 'static + UserData> Default for StaticUserDataMethods<'lua, T> { - fn default() -> StaticUserDataMethods<'lua, T> { - StaticUserDataMethods { - methods: Vec::new(), - #[cfg(feature = "async")] - async_methods: Vec::new(), - meta_methods: Vec::new(), - #[cfg(feature = "async")] - async_meta_methods: Vec::new(), - _type: PhantomData, - } - } -} - -impl<'lua, T: 'static + UserData> UserDataMethods<'lua, T> for StaticUserDataMethods<'lua, T> { - fn add_method(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, - { - self.methods - .push((name.as_ref().to_vec(), Self::box_method(method))); - } - - fn add_method_mut(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.methods - .push((name.as_ref().to_vec(), Self::box_method_mut(method))); - } - - #[cfg(feature = "async")] - fn add_async_method(&mut self, name: &S, method: M) - where - T: Clone, - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>, - { - self.async_methods - .push((name.as_ref().to_vec(), Self::box_async_method(method))); - } - - fn add_function(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - self.methods - .push((name.as_ref().to_vec(), Self::box_function(function))); - } - - fn add_function_mut(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - self.methods - .push((name.as_ref().to_vec(), Self::box_function_mut(function))); - } - - #[cfg(feature = "async")] - fn add_async_function(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - self.async_methods - .push((name.as_ref().to_vec(), Self::box_async_function(function))); - } - - fn add_meta_method(&mut self, meta: S, method: M) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, - { - self.meta_methods - .push((meta.into(), Self::box_method(method))); - } - - fn add_meta_method_mut(&mut self, meta: S, method: M) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, - { - self.meta_methods - .push((meta.into(), Self::box_method_mut(method))); - } - - #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_method(&mut self, meta: S, method: M) - where - T: Clone, - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>, - { - self.async_meta_methods - .push((meta.into(), Self::box_async_method(method))); - } - - fn add_meta_function(&mut self, meta: S, function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - self.meta_methods - .push((meta.into(), Self::box_function(function))); - } - - fn add_meta_function_mut(&mut self, meta: S, function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - self.meta_methods - .push((meta.into(), Self::box_function_mut(function))); - } - - #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_function(&mut self, meta: S, function: F) - where - S: Into, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - self.async_meta_methods - .push((meta.into(), Self::box_async_function(function))); - } - - // Below are internal methods used in generated code - - fn add_callback(&mut self, name: Vec, callback: Callback<'lua, 'static>) { - self.methods.push((name, callback)); - } - - #[cfg(feature = "async")] - fn add_async_callback(&mut self, name: Vec, callback: AsyncCallback<'lua, 'static>) { - self.async_methods.push((name, callback)); - } - - fn add_meta_callback(&mut self, meta: MetaMethod, callback: Callback<'lua, 'static>) { - self.meta_methods.push((meta, callback)); - } - - #[cfg(feature = "async")] - fn add_async_meta_callback( - &mut self, - meta: MetaMethod, - callback: AsyncCallback<'lua, 'static>, - ) { - self.async_meta_methods.push((meta, callback)) - } -} - -impl<'lua, T: 'static + UserData> StaticUserDataMethods<'lua, T> { - fn box_method(method: M) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T, A) -> Result, - { - Box::new(move |lua, mut args| { - if let Some(front) = args.pop_front() { - let userdata = AnyUserData::from_lua(front, lua)?; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - - let type_id = lua.push_userdata_ref(&userdata.0)?; - match type_id { - Some(id) if id == TypeId::of::() => { - let ud = get_userdata_ref::(lua.state)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_lock().ok_or(Error::UserDataBorrowError)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_read().ok_or(Error::UserDataBorrowError)?; - method(lua, &ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - _ => Err(Error::UserDataTypeMismatch), - } - } - } else { - Err(Error::FromLuaConversionError { - from: "missing argument", - to: "userdata", - message: None, - }) - } - }) - } - - fn box_method_mut(method: M) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result, - { - let method = RefCell::new(method); - Box::new(move |lua, mut args| { - if let Some(front) = args.pop_front() { - let userdata = AnyUserData::from_lua(front, lua)?; - let mut method = method - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - - let type_id = lua.push_userdata_ref(&userdata.0)?; - match type_id { - Some(id) if id == TypeId::of::() => { - let mut ud = get_userdata_mut::(lua.state)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(lua.state)?; - let mut ud = ud - .try_borrow_mut() - .map_err(|_| Error::UserDataBorrowMutError)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(lua.state)?; - let mut ud = - ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(lua.state)?; - let mut ud = ud.try_lock().ok_or(Error::UserDataBorrowMutError)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(lua.state)?; - let mut ud = - ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(lua.state)?; - let mut ud = ud.try_write().ok_or(Error::UserDataBorrowMutError)?; - method(lua, &mut ud, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - } - _ => Err(Error::UserDataTypeMismatch), - } - } - } else { - Err(Error::FromLuaConversionError { - from: "missing argument", - to: "userdata", - message: None, - }) - } - }) - } - - #[cfg(feature = "async")] - fn box_async_method(method: M) -> AsyncCallback<'lua, 'static> - where - T: Clone, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, T, A) -> MR, - MR: 'lua + Future>, - { - Box::new(move |lua, mut args| { - let fut_res = || { - if let Some(front) = args.pop_front() { - let userdata = AnyUserData::from_lua(front, lua)?; - unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - - let type_id = lua.push_userdata_ref(&userdata.0)?; - match type_id { - Some(id) if id == TypeId::of::() => { - let ud = get_userdata_ref::(lua.state)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_lock().ok_or(Error::UserDataBorrowError)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(lua.state)?; - let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = - get_userdata_ref::>>(lua.state)?; - let ud = ud.try_read().ok_or(Error::UserDataBorrowError)?; - Ok(method(lua, ud.clone(), A::from_lua_multi(args, lua)?)) - } - _ => Err(Error::UserDataTypeMismatch), - } - } - } else { - Err(Error::FromLuaConversionError { - from: "missing argument", - to: "userdata", - message: None, - }) - } - }; - match fut_res() { - Ok(fut) => Box::pin(fut.and_then(move |ret| future::ready(ret.to_lua_multi(lua)))), - Err(e) => Box::pin(future::err(e)), - } - }) - } - - fn box_function(function: F) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - { - Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)) - } - - fn box_function_mut(function: F) -> Callback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, - { - let function = RefCell::new(function); - Box::new(move |lua, args| { - let function = &mut *function - .try_borrow_mut() - .map_err(|_| Error::RecursiveMutCallback)?; - function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) - }) - } - - #[cfg(feature = "async")] - fn box_async_function(function: F) -> AsyncCallback<'lua, 'static> - where - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, - { - Box::new(move |lua, args| { - let args = match A::from_lua_multi(args, lua) { - Ok(args) => args, - Err(e) => return Box::pin(future::err(e)), - }; - Box::pin(function(lua, args).and_then(move |ret| future::ready(ret.to_lua_multi(lua)))) - }) - } -} - -pub(crate) struct StaticUserDataFields<'lua, T: 'static + UserData> { - pub(crate) field_getters: Vec<(Vec, Callback<'lua, 'static>)>, - pub(crate) field_setters: Vec<(Vec, Callback<'lua, 'static>)>, - #[allow(clippy::type_complexity)] - pub(crate) meta_fields: Vec<( - MetaMethod, - Box Result> + 'static>, - )>, - _type: PhantomData, -} - -impl<'lua, T: 'static + UserData> Default for StaticUserDataFields<'lua, T> { - fn default() -> StaticUserDataFields<'lua, T> { - StaticUserDataFields { - field_getters: Vec::new(), - field_setters: Vec::new(), - meta_fields: Vec::new(), - _type: PhantomData, - } - } -} - -impl<'lua, T: 'static + UserData> UserDataFields<'lua, T> for StaticUserDataFields<'lua, T> { - fn add_field_method_get(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - M: 'static + MaybeSend + Fn(&'lua Lua, &T) -> Result, - { - self.field_getters.push(( - name.as_ref().to_vec(), - StaticUserDataMethods::box_method(move |lua, data, ()| method(lua, data)), - )); - } - - fn add_field_method_set(&mut self, name: &S, method: M) - where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - M: 'static + MaybeSend + FnMut(&'lua Lua, &mut T, A) -> Result<()>, - { - self.field_setters.push(( - name.as_ref().to_vec(), - StaticUserDataMethods::box_method_mut(method), - )); - } - - fn add_field_function_get(&mut self, name: &S, function: F) - where - S: AsRef<[u8]> + ?Sized, - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua, AnyUserData<'lua>) -> Result, - { - self.field_getters.push(( - name.as_ref().to_vec(), - StaticUserDataMethods::::box_function(function), - )); - } - - fn add_field_function_set(&mut self, name: &S, mut function: F) - where - S: AsRef<[u8]> + ?Sized, - A: FromLua<'lua>, - F: 'static + MaybeSend + FnMut(&'lua Lua, AnyUserData<'lua>, A) -> Result<()>, - { - self.field_setters.push(( - name.as_ref().to_vec(), - StaticUserDataMethods::::box_function_mut(move |lua, (data, val)| { - function(lua, data, val) - }), - )); - } - - fn add_meta_field_with(&mut self, meta: S, f: F) - where - S: Into, - R: ToLua<'lua>, - F: 'static + MaybeSend + Fn(&'lua Lua) -> Result, - { - let meta = meta.into(); - self.meta_fields.push(( - meta.clone(), - Box::new(move |lua| { - let value = f(lua)?.to_lua(lua)?; - if meta == MetaMethod::Index || meta == MetaMethod::NewIndex { - match value { - Value::Nil | Value::Table(_) | Value::Function(_) => {} - _ => { - return Err(Error::MetaMethodTypeError { - method: meta.to_string(), - type_name: value.type_name(), - message: Some("expected nil, table or function".to_string()), - }) - } - } - } - Ok(value) - }), - )); - } - - // Below are internal methods - - fn add_field_getter(&mut self, name: Vec, callback: Callback<'lua, 'static>) { - self.field_getters.push((name, callback)); - } - - fn add_field_setter(&mut self, name: Vec, callback: Callback<'lua, 'static>) { - self.field_setters.push((name, callback)); - } -} - -#[inline] -unsafe fn get_userdata_ref<'a, T>(state: *mut ffi::lua_State) -> Result> { - (*get_userdata::>(state, -1)).try_borrow() -} - -#[inline] -unsafe fn get_userdata_mut<'a, T>(state: *mut ffi::lua_State) -> Result> { - (*get_userdata::>(state, -1)).try_borrow_mut() -} - -macro_rules! lua_userdata_impl { - ($type:ty) => { - impl UserData for $type { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - let mut orig_fields = StaticUserDataFields::default(); - T::add_fields(&mut orig_fields); - for (name, callback) in orig_fields.field_getters { - fields.add_field_getter(name, callback); - } - for (name, callback) in orig_fields.field_setters { - fields.add_field_setter(name, callback); - } - } - - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - let mut orig_methods = StaticUserDataMethods::default(); - T::add_methods(&mut orig_methods); - for (name, callback) in orig_methods.methods { - methods.add_callback(name, callback); - } - #[cfg(feature = "async")] - for (name, callback) in orig_methods.async_methods { - methods.add_async_callback(name, callback); - } - for (meta, callback) in orig_methods.meta_methods { - methods.add_meta_callback(meta, callback); - } - #[cfg(feature = "async")] - for (meta, callback) in orig_methods.async_meta_methods { - methods.add_async_meta_callback(meta, callback); - } - } - } - }; -} - -#[cfg(not(feature = "send"))] -lua_userdata_impl!(Rc>); -lua_userdata_impl!(Arc>); -lua_userdata_impl!(Arc>); -#[cfg(feature = "parking_lot")] -lua_userdata_impl!(Arc>); -#[cfg(feature = "parking_lot")] -lua_userdata_impl!(Arc>); diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 063a5c6c..00000000 --- a/src/util.rs +++ /dev/null @@ -1,1005 +0,0 @@ -use std::any::{Any, TypeId}; -use std::ffi::CStr; -use std::fmt::Write; -use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; -use std::sync::Arc; -use std::{mem, ptr, slice}; - -use once_cell::sync::Lazy; -use rustc_hash::FxHashMap; - -use crate::error::{Error, Result}; -use crate::ffi; - -static METATABLE_CACHE: Lazy> = Lazy::new(|| { - let mut map = FxHashMap::with_capacity_and_hasher(32, Default::default()); - crate::lua::init_metatable_cache(&mut map); - map.insert(TypeId::of::(), 0); - map.insert(TypeId::of::(), 0); - map -}); - -// Checks that Lua has enough free stack space for future stack operations. On failure, this will -// panic with an internal error message. -#[inline] -pub unsafe fn assert_stack(state: *mut ffi::lua_State, amount: c_int) { - // TODO: This should only be triggered when there is a logic error in `mlua`. In the future, - // when there is a way to be confident about stack safety and test it, this could be enabled - // only when `cfg!(debug_assertions)` is true. - mlua_assert!( - ffi::lua_checkstack(state, amount) != 0, - "out of stack space" - ); -} - -// Checks that Lua has enough free stack space and returns `Error::StackError` on failure. -#[inline] -pub unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> Result<()> { - if ffi::lua_checkstack(state, amount) == 0 { - Err(Error::StackError) - } else { - Ok(()) - } -} - -pub struct StackGuard { - state: *mut ffi::lua_State, - top: c_int, - extra: c_int, -} - -impl StackGuard { - // Creates a StackGuard instance with record of the stack size, and on Drop will check the - // stack size and drop any extra elements. If the stack size at the end is *smaller* than at - // the beginning, this is considered a fatal logic error and will result in a panic. - #[inline] - pub unsafe fn new(state: *mut ffi::lua_State) -> StackGuard { - StackGuard { - state, - top: ffi::lua_gettop(state), - extra: 0, - } - } - - // Similar to `new`, but checks and keeps `extra` elements from top of the stack on Drop. - #[inline] - pub unsafe fn new_extra(state: *mut ffi::lua_State, extra: c_int) -> StackGuard { - StackGuard { - state, - top: ffi::lua_gettop(state), - extra, - } - } -} - -impl Drop for StackGuard { - fn drop(&mut self) { - unsafe { - let top = ffi::lua_gettop(self.state); - if top < self.top + self.extra { - mlua_panic!("{} too many stack values popped", self.top - top) - } - if top > self.top + self.extra { - if self.extra > 0 { - ffi::lua_rotate(self.state, self.top + 1, self.extra); - } - ffi::lua_settop(self.state, self.top + self.extra); - } - } - } -} - -// Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. -// Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a -// limited lua stack. `nargs` is the same as the the parameter to `lua_pcall`, and `nresults` is -// always `LUA_MULTRET`. Provided function must *not* panic, and since it will generally be lonjmping, -// should not contain any values that implements Drop. -// Internally uses 2 extra stack spaces, and does not call checkstack. -pub unsafe fn protect_lua_call( - state: *mut ffi::lua_State, - nargs: c_int, - f: unsafe extern "C" fn(*mut ffi::lua_State) -> c_int, -) -> Result<()> { - let stack_start = ffi::lua_gettop(state) - nargs; - - ffi::lua_pushcfunction(state, error_traceback); - ffi::lua_pushcfunction(state, f); - if nargs > 0 { - ffi::lua_rotate(state, stack_start + 1, 2); - } - - let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start + 1); - ffi::lua_remove(state, stack_start + 1); - - if ret == ffi::LUA_OK { - Ok(()) - } else { - Err(pop_error(state, ret)) - } -} - -// Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. -// Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a -// limited lua stack. `nargs` and `nresults` are similar to the parameters of `lua_pcall`, but the -// given function return type is not the return value count, instead the inner function return -// values are assumed to match the `nresults` param. Provided function must *not* panic, and since it -// will generally be lonjmping, should not contain any values that implements Drop. -// Internally uses 3 extra stack spaces, and does not call checkstack. -pub unsafe fn protect_lua_closure( - state: *mut ffi::lua_State, - nargs: c_int, - nresults: c_int, - f: F, -) -> Result -where - F: Fn(*mut ffi::lua_State) -> R, - R: Copy, -{ - union URes { - uninit: (), - init: R, - } - - struct Params { - function: F, - result: URes, - nresults: c_int, - } - - unsafe extern "C" fn do_call(state: *mut ffi::lua_State) -> c_int - where - R: Copy, - F: Fn(*mut ffi::lua_State) -> R, - { - let params = ffi::lua_touserdata(state, -1) as *mut Params; - ffi::lua_pop(state, 1); - - (*params).result.init = ((*params).function)(state); - - if (*params).nresults == ffi::LUA_MULTRET { - ffi::lua_gettop(state) - } else { - (*params).nresults - } - } - - let stack_start = ffi::lua_gettop(state) - nargs; - - ffi::lua_pushcfunction(state, error_traceback); - ffi::lua_pushcfunction(state, do_call::); - if nargs > 0 { - ffi::lua_rotate(state, stack_start + 1, 2); - } - - let mut params = Params { - function: f, - result: URes { uninit: () }, - nresults, - }; - - ffi::lua_pushlightuserdata(state, &mut params as *mut Params as *mut c_void); - let ret = ffi::lua_pcall(state, nargs + 1, nresults, stack_start + 1); - ffi::lua_remove(state, stack_start + 1); - - if ret == ffi::LUA_OK { - // `LUA_OK` is only returned when the `do_call` function has completed successfully, so - // `params.result` is definitely initialized. - Ok(params.result.init) - } else { - Err(pop_error(state, ret)) - } -} - -// Pops an error off of the stack and returns it. The specific behavior depends on the type of the -// error at the top of the stack: -// 1) If the error is actually a WrappedPanic, this will continue the panic. -// 2) If the error on the top of the stack is actually a WrappedError, just returns it. -// 3) Otherwise, interprets the error as the appropriate lua error. -// Uses 2 stack spaces, does not call checkstack. -pub unsafe fn pop_error(state: *mut ffi::lua_State, err_code: c_int) -> Error { - mlua_debug_assert!( - err_code != ffi::LUA_OK && err_code != ffi::LUA_YIELD, - "pop_error called with non-error return code" - ); - - match get_gc_userdata::(state, -1).as_mut() { - Some(WrappedFailure::Error(err)) => { - ffi::lua_pop(state, 1); - err.clone() - } - Some(WrappedFailure::Panic(panic)) => { - if let Some(p) = panic.take() { - resume_unwind(p); - } else { - Error::PreviouslyResumedPanic - } - } - _ => { - let err_string = to_string(state, -1); - ffi::lua_pop(state, 1); - - match err_code { - ffi::LUA_ERRRUN => Error::RuntimeError(err_string), - ffi::LUA_ERRSYNTAX => { - Error::SyntaxError { - // This seems terrible, but as far as I can tell, this is exactly what the - // stock Lua REPL does. - incomplete_input: err_string.ends_with("") - || err_string.ends_with("''"), - message: err_string, - } - } - ffi::LUA_ERRERR => { - // This error is raised when the error handler raises an error too many times - // recursively, and continuing to trigger the error handler would cause a stack - // overflow. It is not very useful to differentiate between this and "ordinary" - // runtime errors, so we handle them the same way. - Error::RuntimeError(err_string) - } - ffi::LUA_ERRMEM => Error::MemoryError(err_string), - #[cfg(any(feature = "lua53", feature = "lua52"))] - ffi::LUA_ERRGCMM => Error::GarbageCollectorError(err_string), - _ => mlua_panic!("unrecognized lua error code"), - } - } - } -} - -// Uses 3 stack spaces, does not call checkstack. -#[inline] -pub unsafe fn push_string + ?Sized>( - state: *mut ffi::lua_State, - s: &S, -) -> Result<()> { - let s = s.as_ref(); - protect_lua!(state, 0, 1, |state| { - ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()); - }) -} - -// Uses 3 stack spaces, does not call checkstack. -#[inline] -pub unsafe fn push_table(state: *mut ffi::lua_State, narr: c_int, nrec: c_int) -> Result<()> { - protect_lua!(state, 0, 1, |state| ffi::lua_createtable(state, narr, nrec)) -} - -// Uses 4 stack spaces, does not call checkstack. -pub unsafe fn rawset_field(state: *mut ffi::lua_State, table: c_int, field: &S) -> Result<()> -where - S: AsRef<[u8]> + ?Sized, -{ - let field = field.as_ref(); - ffi::lua_pushvalue(state, table); - protect_lua!(state, 2, 0, |state| { - ffi::lua_pushlstring(state, field.as_ptr() as *const c_char, field.len()); - ffi::lua_rotate(state, -3, 2); - ffi::lua_rawset(state, -3); - }) -} - -// Internally uses 3 stack spaces, does not call checkstack. -#[cfg(not(feature = "luau"))] -#[inline] -pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { - let ud = protect_lua!(state, 0, 1, |state| { - ffi::lua_newuserdata(state, mem::size_of::()) as *mut T - })?; - ptr::write(ud, t); - Ok(()) -} - -// Internally uses 3 stack spaces, does not call checkstack. -#[cfg(feature = "luau")] -#[inline] -pub unsafe fn push_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { - unsafe extern "C" fn destructor(ud: *mut c_void) { - let ud = ud as *mut T; - if *(ud.offset(1) as *mut u8) == 0 { - ptr::drop_in_place(ud); - } - } - - let ud = protect_lua!(state, 0, 1, |state| { - let size = mem::size_of::() + 1; - ffi::lua_newuserdatadtor(state, size, destructor::) as *mut T - })?; - ptr::write(ud, t); - *(ud.offset(1) as *mut u8) = 0; // Mark as not destructed - - Ok(()) -} - -// Internally uses 3 stack spaces, does not call checkstack. -#[cfg(feature = "lua54")] -#[inline] -pub unsafe fn push_userdata_uv(state: *mut ffi::lua_State, t: T, nuvalue: c_int) -> Result<()> { - let ud = protect_lua!(state, 0, 1, |state| { - ffi::lua_newuserdatauv(state, mem::size_of::(), nuvalue) as *mut T - })?; - ptr::write(ud, t); - Ok(()) -} - -#[inline] -pub unsafe fn get_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { - let ud = ffi::lua_touserdata(state, index) as *mut T; - mlua_debug_assert!(!ud.is_null(), "userdata pointer is null"); - ud -} - -// Pops the userdata off of the top of the stack and returns it to rust, invalidating the lua -// userdata and gives it the special "destructed" userdata metatable. Userdata must not have been -// previously invalidated, and this method does not check for this. -// Uses 1 extra stack space and does not call checkstack. -pub unsafe fn take_userdata(state: *mut ffi::lua_State) -> T { - // We set the metatable of userdata on __gc to a special table with no __gc method and with - // metamethods that trigger an error on access. We do this so that it will not be double - // dropped, and also so that it cannot be used or identified as any particular userdata type - // after the first call to __gc. - get_destructed_userdata_metatable(state); - ffi::lua_setmetatable(state, -2); - let ud = get_userdata::(state, -1); - ffi::lua_pop(state, 1); - if cfg!(feature = "luau") { - *(ud.offset(1) as *mut u8) = 1; // Mark as destructed - } - ptr::read(ud) -} - -// Pushes the userdata and attaches a metatable with __gc method. -// Internally uses 3 stack spaces, does not call checkstack. -pub unsafe fn push_gc_userdata(state: *mut ffi::lua_State, t: T) -> Result<()> { - push_userdata(state, t)?; - get_gc_metatable::(state); - ffi::lua_setmetatable(state, -2); - Ok(()) -} - -// Uses 2 stack spaces, does not call checkstack -pub unsafe fn get_gc_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { - let ud = ffi::lua_touserdata(state, index) as *mut T; - if ud.is_null() || ffi::lua_getmetatable(state, index) == 0 { - return ptr::null_mut(); - } - get_gc_metatable::(state); - let res = ffi::lua_rawequal(state, -1, -2); - ffi::lua_pop(state, 2); - if res == 0 { - return ptr::null_mut(); - } - ud -} - -unsafe extern "C" fn lua_error_impl(state: *mut ffi::lua_State) -> c_int { - ffi::lua_error(state); -} - -unsafe extern "C" fn lua_isfunction_impl(state: *mut ffi::lua_State) -> c_int { - let t = ffi::lua_type(state, -1); - ffi::lua_pop(state, 1); - ffi::lua_pushboolean(state, (t == ffi::LUA_TFUNCTION) as c_int); - 1 -} - -unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> { - let index_key = &USERDATA_METATABLE_INDEX as *const u8 as *const _; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, index_key) == ffi::LUA_TFUNCTION { - return Ok(()); - } - ffi::lua_pop(state, 1); - - // Create and cache `__index` helper - let code = cstr!( - r#" - local error, isfunction = ... - return function (__index, field_getters, methods) - return function (self, key) - if field_getters ~= nil then - local field_getter = field_getters[key] - if field_getter ~= nil then - return field_getter(self) - end - end - - if methods ~= nil then - local method = methods[key] - if method ~= nil then - return method - end - end - - if isfunction(__index) then - return __index(self, key) - elseif __index == nil then - error("attempt to get an unknown field '"..key.."'") - else - return __index[key] - end - end - end - "# - ); - let code_len = CStr::from_ptr(code).to_bytes().len(); - protect_lua!(state, 0, 1, |state| { - let ret = ffi::luaL_loadbuffer(state, code, code_len, cstr!("__mlua_index")); - if ret != ffi::LUA_OK { - ffi::lua_error(state); - } - ffi::lua_pushcfunction(state, lua_error_impl); - ffi::lua_pushcfunction(state, lua_isfunction_impl); - ffi::lua_call(state, 2, 1); - - // Store in the registry - ffi::lua_pushvalue(state, -1); - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, index_key); - }) -} - -pub unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> { - let newindex_key = &USERDATA_METATABLE_NEWINDEX as *const u8 as *const _; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, newindex_key) == ffi::LUA_TFUNCTION { - return Ok(()); - } - ffi::lua_pop(state, 1); - - // Create and cache `__newindex` helper - let code = cstr!( - r#" - local error, isfunction = ... - return function (__newindex, field_setters) - return function (self, key, value) - if field_setters ~= nil then - local field_setter = field_setters[key] - if field_setter ~= nil then - field_setter(self, value) - return - end - end - - if isfunction(__newindex) then - __newindex(self, key, value) - elseif __newindex == nil then - error("attempt to set an unknown field '"..key.."'") - else - __newindex[key] = value - end - end - end - "# - ); - let code_len = CStr::from_ptr(code).to_bytes().len(); - protect_lua!(state, 0, 1, |state| { - let ret = ffi::luaL_loadbuffer(state, code, code_len, cstr!("__mlua_newindex")); - if ret != ffi::LUA_OK { - ffi::lua_error(state); - } - ffi::lua_pushcfunction(state, lua_error_impl); - ffi::lua_pushcfunction(state, lua_isfunction_impl); - ffi::lua_call(state, 2, 1); - - // Store in the registry - ffi::lua_pushvalue(state, -1); - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, newindex_key); - }) -} - -// Populates the given table with the appropriate members to be a userdata metatable for the given type. -// This function takes the given table at the `metatable` index, and adds an appropriate `__gc` member -// to it for the given type and a `__metatable` entry to protect the table from script access. -// The function also, if given a `field_getters` or `methods` tables, will create an `__index` metamethod -// (capturing previous one) to lookup in `field_getters` first, then `methods` and falling back to the -// captured `__index` if no matches found. -// The same is also applicable for `__newindex` metamethod and `field_setters` table. -// Internally uses 9 stack spaces and does not call checkstack. -pub unsafe fn init_userdata_metatable( - state: *mut ffi::lua_State, - metatable: c_int, - field_getters: Option, - field_setters: Option, - methods: Option, -) -> Result<()> { - ffi::lua_pushvalue(state, metatable); - - if field_getters.is_some() || methods.is_some() { - // Push `__index` generator function - init_userdata_metatable_index(state)?; - - push_string(state, "__index")?; - let index_type = ffi::lua_rawget(state, -3); - match index_type { - ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { - for &idx in &[field_getters, methods] { - if let Some(idx) = idx { - ffi::lua_pushvalue(state, idx); - } else { - ffi::lua_pushnil(state); - } - } - - // Generate `__index` - protect_lua!(state, 4, 1, fn(state) ffi::lua_call(state, 3, 1))?; - } - _ => mlua_panic!("improper __index type {}", index_type), - } - - rawset_field(state, -2, "__index")?; - } - - if let Some(field_setters) = field_setters { - // Push `__newindex` generator function - init_userdata_metatable_newindex(state)?; - - push_string(state, "__newindex")?; - let newindex_type = ffi::lua_rawget(state, -3); - match newindex_type { - ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { - ffi::lua_pushvalue(state, field_setters); - // Generate `__newindex` - protect_lua!(state, 3, 1, fn(state) ffi::lua_call(state, 2, 1))?; - } - _ => mlua_panic!("improper __newindex type {}", newindex_type), - } - - rawset_field(state, -2, "__newindex")?; - } - - #[cfg(not(feature = "luau"))] - { - ffi::lua_pushcfunction(state, userdata_destructor::); - rawset_field(state, -2, "__gc")?; - } - - ffi::lua_pushboolean(state, 0); - rawset_field(state, -2, "__metatable")?; - - ffi::lua_pop(state, 1); - - Ok(()) -} - -#[cfg(not(feature = "luau"))] -pub unsafe extern "C" fn userdata_destructor(state: *mut ffi::lua_State) -> c_int { - // It's probably NOT a good idea to catch Rust panics in finalizer - // Lua 5.4 ignores it, other versions generates `LUA_ERRGCMM` without calling message handler - take_userdata::(state); - 0 -} - -// In the context of a lua callback, this will call the given function and if the given function -// returns an error, *or if the given function panics*, this will result in a call to `lua_error` (a -// longjmp). The error or panic is wrapped in such a way that when calling `pop_error` back on -// the Rust side, it will resume the panic. -// -// This function assumes the structure of the stack at the beginning of a callback, that the only -// elements on the stack are the arguments to the callback. -// -// This function uses some of the bottom of the stack for error handling, the given callback will be -// given the number of arguments available as an argument, and should return the number of returns -// as normal, but cannot assume that the arguments available start at 0. -pub unsafe fn callback_error(state: *mut ffi::lua_State, f: F) -> R -where - F: FnOnce(c_int) -> Result, -{ - let nargs = ffi::lua_gettop(state); - - // We need 2 extra stack spaces to store preallocated memory and error/panic metatable - let extra_stack = if nargs < 2 { 2 - nargs } else { 1 }; - ffi::luaL_checkstack( - state, - extra_stack, - cstr!("not enough stack space for callback error handling"), - ); - - // We cannot shadow Rust errors with Lua ones, we pre-allocate enough memory - // to store a wrapped error or panic *before* we proceed. - let ud = WrappedFailure::new_userdata(state); - ffi::lua_rotate(state, 1, 1); - - match catch_unwind(AssertUnwindSafe(|| f(nargs))) { - Ok(Ok(r)) => { - ffi::lua_remove(state, 1); - r - } - Ok(Err(err)) => { - ffi::lua_settop(state, 1); - - let wrapped_error = ud as *mut WrappedFailure; - - // Build `CallbackError` with traceback - let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { - ffi::luaL_traceback(state, state, ptr::null(), 0); - let traceback = to_string(state, -1); - ffi::lua_pop(state, 1); - traceback - } else { - "".to_string() - }; - let cause = Arc::new(err); - ptr::write( - wrapped_error, - WrappedFailure::Error(Error::CallbackError { traceback, cause }), - ); - get_gc_metatable::(state); - ffi::lua_setmetatable(state, -2); - - ffi::lua_error(state) - } - Err(p) => { - ffi::lua_settop(state, 1); - ptr::write(ud as *mut WrappedFailure, WrappedFailure::Panic(Some(p))); - get_gc_metatable::(state); - ffi::lua_setmetatable(state, -2); - ffi::lua_error(state) - } - } -} - -pub unsafe extern "C" fn error_traceback(state: *mut ffi::lua_State) -> c_int { - if ffi::lua_checkstack(state, 2) == 0 { - // If we don't have enough stack space to even check the error type, do - // nothing so we don't risk shadowing a rust panic. - return 1; - } - - if get_gc_userdata::(state, -1).is_null() { - let s = ffi::luaL_tolstring(state, -1, ptr::null_mut()); - if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { - ffi::luaL_traceback(state, state, s, 0); - ffi::lua_remove(state, -2); - } - } - - 1 -} - -// A variant of `pcall` that does not allow Lua to catch Rust panics from `callback_error`. -pub unsafe extern "C" fn safe_pcall(state: *mut ffi::lua_State) -> c_int { - ffi::luaL_checkstack(state, 2, ptr::null()); - - let top = ffi::lua_gettop(state); - if top == 0 { - ffi::lua_pushstring(state, cstr!("not enough arguments to pcall")); - ffi::lua_error(state); - } - - if ffi::lua_pcall(state, top - 1, ffi::LUA_MULTRET, 0) == ffi::LUA_OK { - ffi::lua_pushboolean(state, 1); - ffi::lua_insert(state, 1); - ffi::lua_gettop(state) - } else { - if let Some(WrappedFailure::Panic(_)) = - get_gc_userdata::(state, -1).as_ref() - { - ffi::lua_error(state); - } - ffi::lua_pushboolean(state, 0); - ffi::lua_insert(state, -2); - 2 - } -} - -// A variant of `xpcall` that does not allow Lua to catch Rust panics from `callback_error`. -pub unsafe extern "C" fn safe_xpcall(state: *mut ffi::lua_State) -> c_int { - unsafe extern "C" fn xpcall_msgh(state: *mut ffi::lua_State) -> c_int { - ffi::luaL_checkstack(state, 2, ptr::null()); - - if let Some(WrappedFailure::Panic(_)) = - get_gc_userdata::(state, -1).as_ref() - { - 1 - } else { - ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); - ffi::lua_insert(state, 1); - ffi::lua_call(state, ffi::lua_gettop(state) - 1, ffi::LUA_MULTRET); - ffi::lua_gettop(state) - } - } - - ffi::luaL_checkstack(state, 2, ptr::null()); - - let top = ffi::lua_gettop(state); - if top < 2 { - ffi::lua_pushstring(state, cstr!("not enough arguments to xpcall")); - ffi::lua_error(state); - } - - ffi::lua_pushvalue(state, 2); - ffi::lua_pushcclosure(state, xpcall_msgh, 1); - ffi::lua_copy(state, 1, 2); - ffi::lua_replace(state, 1); - - if ffi::lua_pcall(state, ffi::lua_gettop(state) - 2, ffi::LUA_MULTRET, 1) == ffi::LUA_OK { - ffi::lua_pushboolean(state, 1); - ffi::lua_insert(state, 2); - ffi::lua_gettop(state) - 1 - } else { - if let Some(WrappedFailure::Panic(_)) = - get_gc_userdata::(state, -1).as_ref() - { - ffi::lua_error(state); - } - ffi::lua_pushboolean(state, 0); - ffi::lua_insert(state, -2); - 2 - } -} - -// Returns Lua main thread for Lua >= 5.2 or checks that the passed thread is main for Lua 5.1. -// Does not call lua_checkstack, uses 1 stack space. -pub unsafe fn get_main_state(state: *mut ffi::lua_State) -> Option<*mut ffi::lua_State> { - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - { - ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD); - let main_state = ffi::lua_tothread(state, -1); - ffi::lua_pop(state, 1); - Some(main_state) - } - #[cfg(any(feature = "lua51", feature = "luajit"))] - { - // Check the current state first - let is_main_state = ffi::lua_pushthread(state) == 1; - ffi::lua_pop(state, 1); - if is_main_state { - Some(state) - } else { - None - } - } - #[cfg(feature = "luau")] - Some(ffi::lua_mainthread(state)) -} - -// Initialize the internal (with __gc method) metatable for a type T. -// Uses 6 stack spaces and calls checkstack. -pub unsafe fn init_gc_metatable( - state: *mut ffi::lua_State, - customize_fn: Option Result<()>>, -) -> Result<()> { - check_stack(state, 6)?; - - push_table(state, 0, 3)?; - - #[cfg(not(feature = "luau"))] - { - ffi::lua_pushcfunction(state, userdata_destructor::); - rawset_field(state, -2, "__gc")?; - } - - ffi::lua_pushboolean(state, 0); - rawset_field(state, -2, "__metatable")?; - - if let Some(f) = customize_fn { - f(state)?; - } - - let type_id = TypeId::of::(); - let ref_addr = &METATABLE_CACHE[&type_id] as *const u8; - protect_lua!(state, 1, 0, |state| { - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, ref_addr as *const c_void); - })?; - - Ok(()) -} - -pub unsafe fn get_gc_metatable(state: *mut ffi::lua_State) { - let type_id = TypeId::of::(); - let ref_addr = - mlua_expect!(METATABLE_CACHE.get(&type_id), "gc metatable does not exist") as *const u8; - ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, ref_addr as *const c_void); -} - -// Initialize the error, panic, and destructed userdata metatables. -pub unsafe fn init_error_registry(state: *mut ffi::lua_State) -> Result<()> { - check_stack(state, 7)?; - - // Create error and panic metatables - - unsafe extern "C" fn error_tostring(state: *mut ffi::lua_State) -> c_int { - callback_error(state, |_| { - check_stack(state, 3)?; - - let err_buf = match get_gc_userdata::(state, -1).as_ref() { - Some(WrappedFailure::Error(error)) => { - let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; - ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); - let err_buf = ffi::lua_touserdata(state, -1) as *mut String; - ffi::lua_pop(state, 2); - - (*err_buf).clear(); - // Depending on how the API is used and what error types scripts are given, it may - // be possible to make this consume arbitrary amounts of memory (for example, some - // kind of recursive error structure?) - let _ = write!(&mut (*err_buf), "{}", error); - Ok(err_buf) - } - Some(WrappedFailure::Panic(Some(ref panic))) => { - let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; - ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); - let err_buf = ffi::lua_touserdata(state, -1) as *mut String; - (*err_buf).clear(); - ffi::lua_pop(state, 2); - - if let Some(msg) = panic.downcast_ref::<&str>() { - let _ = write!(&mut (*err_buf), "{}", msg); - } else if let Some(msg) = panic.downcast_ref::() { - let _ = write!(&mut (*err_buf), "{}", msg); - } else { - let _ = write!(&mut (*err_buf), ""); - }; - Ok(err_buf) - } - Some(WrappedFailure::Panic(None)) => Err(Error::PreviouslyResumedPanic), - _ => { - // I'm not sure whether this is possible to trigger without bugs in mlua? - Err(Error::UserDataTypeMismatch) - } - }?; - - push_string(state, &*err_buf)?; - (*err_buf).clear(); - - Ok(1) - }) - } - - init_gc_metatable::( - state, - Some(|state| { - ffi::lua_pushcfunction(state, error_tostring); - rawset_field(state, -2, "__tostring") - }), - )?; - - // Create destructed userdata metatable - - unsafe extern "C" fn destructed_error(state: *mut ffi::lua_State) -> c_int { - callback_error(state, |_| Err(Error::CallbackDestructed)) - } - - push_table(state, 0, 26)?; - ffi::lua_pushcfunction(state, destructed_error); - for &method in &[ - "__add", - "__sub", - "__mul", - "__div", - "__mod", - "__pow", - "__unm", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__idiv", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__band", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bor", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bxor", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__bnot", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__shl", - #[cfg(any(feature = "lua54", feature = "lua53"))] - "__shr", - "__concat", - "__len", - "__eq", - "__lt", - "__le", - "__index", - "__newindex", - "__call", - "__tostring", - #[cfg(any( - feature = "lua54", - feature = "lua53", - feature = "lua52", - feature = "luajit52" - ))] - "__pairs", - #[cfg(any(feature = "lua53", feature = "lua52", feature = "luajit52"))] - "__ipairs", - #[cfg(feature = "lua54")] - "__close", - ] { - ffi::lua_pushvalue(state, -1); - rawset_field(state, -3, method)?; - } - ffi::lua_pop(state, 1); - - protect_lua!(state, 1, 0, fn(state) { - let destructed_mt_key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, destructed_mt_key); - })?; - - // Create error print buffer - init_gc_metatable::(state, None)?; - push_gc_userdata(state, String::new())?; - protect_lua!(state, 1, 0, fn(state) { - let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); - })?; - - Ok(()) -} - -pub(crate) enum WrappedFailure { - None, - Error(Error), - Panic(Option>), -} - -impl WrappedFailure { - pub(crate) unsafe fn new_userdata(state: *mut ffi::lua_State) -> *mut Self { - let size = mem::size_of::(); - #[cfg(feature = "luau")] - let ud = { - unsafe extern "C" fn destructor(p: *mut c_void) { - ptr::drop_in_place(p as *mut WrappedFailure); - } - ffi::lua_newuserdatadtor(state, size, destructor) as *mut Self - }; - #[cfg(not(feature = "luau"))] - let ud = ffi::lua_newuserdata(state, size) as *mut Self; - ptr::write(ud, WrappedFailure::None); - ud - } -} - -// Converts the given lua value to a string in a reasonable format without causing a Lua error or -// panicking. -pub(crate) unsafe fn to_string(state: *mut ffi::lua_State, index: c_int) -> String { - match ffi::lua_type(state, index) { - ffi::LUA_TNONE => "".to_string(), - ffi::LUA_TNIL => "".to_string(), - ffi::LUA_TBOOLEAN => (ffi::lua_toboolean(state, index) != 1).to_string(), - ffi::LUA_TLIGHTUSERDATA => { - format!("", ffi::lua_topointer(state, index)) - } - ffi::LUA_TNUMBER => { - let mut isint = 0; - let i = ffi::lua_tointegerx(state, -1, &mut isint); - if isint == 0 { - ffi::lua_tonumber(state, index).to_string() - } else { - i.to_string() - } - } - #[cfg(feature = "luau")] - ffi::LUA_TVECTOR => { - let v = ffi::lua_tovector(state, index); - mlua_debug_assert!(!v.is_null(), "vector is null"); - let (x, y, z) = (*v, *v.add(1), *v.add(2)); - format!("vector({},{},{})", x, y, z) - } - ffi::LUA_TSTRING => { - let mut size = 0; - // This will not trigger a 'm' error, because the reference is guaranteed to be of - // string type - let data = ffi::lua_tolstring(state, index, &mut size); - String::from_utf8_lossy(slice::from_raw_parts(data as *const u8, size)).into_owned() - } - ffi::LUA_TTABLE => format!("
", ffi::lua_topointer(state, index)), - ffi::LUA_TFUNCTION => format!("", ffi::lua_topointer(state, index)), - ffi::LUA_TUSERDATA => format!("", ffi::lua_topointer(state, index)), - ffi::LUA_TTHREAD => format!("", ffi::lua_topointer(state, index)), - _ => "".to_string(), - } -} - -pub(crate) unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_State) { - let key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; - ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, key); -} - -pub(crate) unsafe fn ptr_to_cstr_bytes<'a>(input: *const c_char) -> Option<&'a [u8]> { - if input.is_null() { - return None; - } - Some(CStr::from_ptr(input).to_bytes()) -} - -static DESTRUCTED_USERDATA_METATABLE: u8 = 0; -static ERROR_PRINT_BUFFER_KEY: u8 = 0; -static USERDATA_METATABLE_INDEX: u8 = 0; -static USERDATA_METATABLE_NEWINDEX: u8 = 0; diff --git a/src/util/error.rs b/src/util/error.rs new file mode 100644 index 00000000..c84902de --- /dev/null +++ b/src/util/error.rs @@ -0,0 +1,435 @@ +use std::any::Any; +use std::fmt::Write as _; +use std::mem::MaybeUninit; +use std::os::raw::{c_int, c_void}; +use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; +use std::ptr; +use std::sync::Arc; + +use crate::error::{Error, Result}; +use crate::memory::MemoryState; +use crate::util::{ + DESTRUCTED_USERDATA_METATABLE, TypeKey, check_stack, get_internal_userdata, init_internal_metatable, + push_internal_userdata, push_string, push_table, rawset_field, to_string, +}; + +static WRAPPED_FAILURE_TYPE_KEY: u8 = 0; + +pub(crate) enum WrappedFailure { + None, + Error(Error), + Panic(Option>), +} + +impl TypeKey for WrappedFailure { + #[inline(always)] + fn type_key() -> *const c_void { + &WRAPPED_FAILURE_TYPE_KEY as *const u8 as *const c_void + } +} + +impl WrappedFailure { + pub(crate) unsafe fn new_userdata(state: *mut ffi::lua_State) -> *mut Self { + // Unprotected calls always return `Ok` + push_internal_userdata(state, WrappedFailure::None, false).unwrap() + } +} + +// In the context of a lua callback, this will call the given function and if the given function +// returns an error, *or if the given function panics*, this will result in a call to `lua_error` (a +// longjmp). The error or panic is wrapped in such a way that when calling `pop_error` back on +// the Rust side, it will resume the panic. +// +// This function assumes the structure of the stack at the beginning of a callback, that the only +// elements on the stack are the arguments to the callback. +// +// This function uses some of the bottom of the stack for error handling, the given callback will be +// given the number of arguments available as an argument, and should return the number of returns +// as normal, but cannot assume that the arguments available start at 0. +unsafe fn callback_error(state: *mut ffi::lua_State, f: F) -> R +where + F: FnOnce(c_int) -> Result, +{ + let nargs = ffi::lua_gettop(state); + + // We need 2 extra stack spaces to store preallocated memory and error/panic metatable + let extra_stack = if nargs < 2 { 2 - nargs } else { 1 }; + ffi::luaL_checkstack( + state, + extra_stack, + cstr!("not enough stack space for callback error handling"), + ); + + // We cannot shadow Rust errors with Lua ones, we pre-allocate enough memory + // to store a wrapped error or panic *before* we proceed. + let ud = WrappedFailure::new_userdata(state); + ffi::lua_rotate(state, 1, 1); + + match catch_unwind(AssertUnwindSafe(|| f(nargs))) { + Ok(Ok(r)) => { + ffi::lua_remove(state, 1); + r + } + Ok(Err(err)) => { + ffi::lua_settop(state, 1); + + // Build `CallbackError` with traceback + let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, ptr::null(), 0); + let traceback = to_string(state, -1); + ffi::lua_pop(state, 1); + traceback + } else { + "".to_string() + }; + let cause = Arc::new(err); + let wrapped_error = WrappedFailure::Error(Error::CallbackError { traceback, cause }); + ptr::write(ud, wrapped_error); + ffi::lua_error(state) + } + Err(p) => { + ffi::lua_settop(state, 1); + ptr::write(ud, WrappedFailure::Panic(Some(p))); + ffi::lua_error(state) + } + } +} + +// Pops an error off of the stack and returns it. The specific behavior depends on the type of the +// error at the top of the stack: +// 1) If the error is actually a panic, this will continue the panic. +// 2) If the error on the top of the stack is actually an error, just returns it. +// 3) Otherwise, interprets the error as the appropriate lua error. +// Uses 2 stack spaces, does not call checkstack. +pub(crate) unsafe fn pop_error(state: *mut ffi::lua_State, err_code: c_int) -> Error { + mlua_debug_assert!( + err_code != ffi::LUA_OK && err_code != ffi::LUA_YIELD, + "pop_error called with non-error return code" + ); + + match get_internal_userdata::(state, -1, ptr::null()).as_mut() { + Some(WrappedFailure::Error(err)) => { + ffi::lua_pop(state, 1); + err.clone() + } + Some(WrappedFailure::Panic(panic)) => { + if let Some(p) = panic.take() { + resume_unwind(p); + } else { + Error::PreviouslyResumedPanic + } + } + _ => { + let err_string = to_string(state, -1); + ffi::lua_pop(state, 1); + + match err_code { + ffi::LUA_ERRRUN => Error::RuntimeError(err_string), + ffi::LUA_ERRSYNTAX => { + Error::SyntaxError { + // This seems terrible, but as far as I can tell, this is exactly what the + // stock Lua REPL does. + incomplete_input: err_string.ends_with("") || err_string.ends_with("''"), + message: err_string, + } + } + ffi::LUA_ERRERR => { + // This error is raised when the error handler raises an error too many times + // recursively, and continuing to trigger the error handler would cause a stack + // overflow. It is not very useful to differentiate between this and "ordinary" + // runtime errors, so we handle them the same way. + Error::RuntimeError(err_string) + } + ffi::LUA_ERRMEM => Error::MemoryError(err_string), + #[cfg(any(feature = "lua53", feature = "lua52"))] + ffi::LUA_ERRGCMM => Error::GarbageCollectorError(err_string), + _ => mlua_panic!("unrecognized lua error code"), + } + } + } +} + +// Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. +// Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a +// limited lua stack. `nargs` is the same as the the parameter to `lua_pcall`, and `nresults` is +// always `LUA_MULTRET`. Provided function must *not* panic, and since it will generally be +// longjmping, should not contain any values that implements Drop. +// Internally uses 2 extra stack spaces, and does not call checkstack. +pub(crate) unsafe fn protect_lua_call( + state: *mut ffi::lua_State, + nargs: c_int, + f: unsafe extern "C-unwind" fn(*mut ffi::lua_State) -> c_int, +) -> Result<()> { + let stack_start = ffi::lua_gettop(state) - nargs; + + MemoryState::relax_limit_with(state, || { + ffi::lua_pushcfunction(state, error_traceback); + ffi::lua_pushcfunction(state, f); + }); + if nargs > 0 { + ffi::lua_rotate(state, stack_start + 1, 2); + } + + let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start + 1); + ffi::lua_remove(state, stack_start + 1); + + if ret == ffi::LUA_OK { + Ok(()) + } else { + Err(pop_error(state, ret)) + } +} + +// Call a function that calls into the Lua API and may trigger a Lua error (longjmp) in a safe way. +// Wraps the inner function in a call to `lua_pcall`, so the inner function only has access to a +// limited lua stack. `nargs` and `nresults` are similar to the parameters of `lua_pcall`, but the +// given function return type is not the return value count, instead the inner function return +// values are assumed to match the `nresults` param. Provided function must *not* panic, and since +// it will generally be longjmping, should not contain any values that implements Drop. +// Internally uses 3 extra stack spaces, and does not call checkstack. +pub(crate) unsafe fn protect_lua_closure( + state: *mut ffi::lua_State, + nargs: c_int, + nresults: c_int, + f: F, +) -> Result +where + F: FnOnce(*mut ffi::lua_State) -> R, + R: Copy, +{ + struct Params { + function: Option, + result: MaybeUninit, + nresults: c_int, + } + + unsafe extern "C-unwind" fn do_call(state: *mut ffi::lua_State) -> c_int + where + F: FnOnce(*mut ffi::lua_State) -> R, + R: Copy, + { + let params = ffi::lua_tolightuserdata(state, -1) as *mut Params; + ffi::lua_pop(state, 1); + + let f = (*params).function.take().unwrap(); + (*params).result.write(f(state)); + + if (*params).nresults == ffi::LUA_MULTRET { + ffi::lua_gettop(state) + } else { + (*params).nresults + } + } + + let stack_start = ffi::lua_gettop(state) - nargs; + + MemoryState::relax_limit_with(state, || { + ffi::lua_pushcfunction(state, error_traceback); + ffi::lua_pushcfunction(state, do_call::); + }); + if nargs > 0 { + ffi::lua_rotate(state, stack_start + 1, 2); + } + + let mut params = Params { + function: Some(f), + result: MaybeUninit::uninit(), + nresults, + }; + + ffi::lua_pushlightuserdata(state, &mut params as *mut Params as *mut c_void); + let ret = ffi::lua_pcall(state, nargs + 1, nresults, stack_start + 1); + ffi::lua_remove(state, stack_start + 1); // remove error handler + + if ret == ffi::LUA_OK { + // `LUA_OK` is only returned when the `do_call` function has completed successfully, so + // `params.result` is definitely initialized. + Ok(params.result.assume_init()) + } else { + Err(pop_error(state, ret)) + } +} + +pub(crate) unsafe extern "C-unwind" fn error_traceback(state: *mut ffi::lua_State) -> c_int { + // Luau calls error handler for memory allocation errors, skip it + // See https://github.com/luau-lang/luau/issues/880 + #[cfg(feature = "luau")] + if MemoryState::limit_reached(state) { + return 0; + } + + if ffi::lua_checkstack(state, 2) == 0 { + // If we don't have enough stack space to even check the error type, do + // nothing so we don't risk shadowing a rust panic. + return 1; + } + + if get_internal_userdata::(state, -1, ptr::null()).is_null() { + let s = ffi::luaL_tolstring(state, -1, ptr::null_mut()); + if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, state, s, 0); + ffi::lua_remove(state, -2); + } + } + + 1 +} + +// A variant of `error_traceback` that can safely inspect another (yielded) thread stack +pub(crate) unsafe fn error_traceback_thread(state: *mut ffi::lua_State, thread: *mut ffi::lua_State) { + // Move error object to the main thread to safely call `__tostring` metamethod if present + ffi::lua_xmove(thread, state, 1); + + if get_internal_userdata::(state, -1, ptr::null()).is_null() { + let s = ffi::luaL_tolstring(state, -1, ptr::null_mut()); + if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 { + ffi::luaL_traceback(state, thread, s, 0); + ffi::lua_remove(state, -2); + } + } +} + +// Initialize the error, panic, and destructed userdata metatables. +pub(crate) unsafe fn init_error_registry(state: *mut ffi::lua_State) -> Result<()> { + check_stack(state, 7)?; + + // Create error and panic metatables + + static ERROR_PRINT_BUFFER_KEY: u8 = 0; + + unsafe extern "C-unwind" fn error_tostring(state: *mut ffi::lua_State) -> c_int { + callback_error(state, |_| { + check_stack(state, 3)?; + + let err_buf = match get_internal_userdata::(state, -1, ptr::null()).as_ref() { + Some(WrappedFailure::Error(error)) => { + let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; + ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); + let err_buf = ffi::lua_touserdata(state, -1) as *mut String; + ffi::lua_pop(state, 2); + + (*err_buf).clear(); + // Depending on how the API is used and what error types scripts are given, it may + // be possible to make this consume arbitrary amounts of memory (for example, some + // kind of recursive error structure?) + let _ = write!(&mut (*err_buf), "{error}"); + Ok(err_buf) + } + Some(WrappedFailure::Panic(Some(panic))) => { + let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; + ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); + let err_buf = ffi::lua_touserdata(state, -1) as *mut String; + (*err_buf).clear(); + ffi::lua_pop(state, 2); + + if let Some(msg) = panic.downcast_ref::<&str>() { + let _ = write!(&mut (*err_buf), "{msg}"); + } else if let Some(msg) = panic.downcast_ref::() { + let _ = write!(&mut (*err_buf), "{msg}"); + } else { + let _ = write!(&mut (*err_buf), ""); + }; + Ok(err_buf) + } + Some(WrappedFailure::Panic(None)) => Err(Error::PreviouslyResumedPanic), + _ => { + // I'm not sure whether this is possible to trigger without bugs in mlua? + Err(Error::UserDataTypeMismatch) + } + }?; + + push_string(state, (*err_buf).as_bytes(), true)?; + (*err_buf).clear(); + + Ok(1) + }) + } + + init_internal_metatable::( + state, + Some(|state| { + ffi::lua_pushcfunction(state, error_tostring); + ffi::lua_setfield(state, -2, cstr!("__tostring")); + + // This is mostly for Luau typeof() function + ffi::lua_pushstring(state, cstr!("error")); + ffi::lua_setfield(state, -2, cstr!("__type")); + }), + )?; + + // Create destructed userdata metatable + + unsafe extern "C-unwind" fn destructed_error(state: *mut ffi::lua_State) -> c_int { + callback_error(state, |_| Err(Error::UserDataDestructed)) + } + + push_table(state, 0, 26, true)?; + ffi::lua_pushcfunction(state, destructed_error); + for &method in &[ + "__add", + "__sub", + "__mul", + "__div", + "__mod", + "__pow", + "__unm", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] + "__idiv", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__band", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__bor", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__bxor", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__bnot", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__shl", + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + "__shr", + "__concat", + "__len", + "__eq", + "__lt", + "__le", + "__index", + "__newindex", + "__call", + "__tostring", + #[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luajit52" + ))] + "__pairs", + #[cfg(any(feature = "lua53", feature = "lua52", feature = "luajit52"))] + "__ipairs", + #[cfg(feature = "luau")] + "__iter", + #[cfg(feature = "luau")] + "__namecall", + #[cfg(any(feature = "lua55", feature = "lua54"))] + "__close", + ] { + ffi::lua_pushvalue(state, -1); + rawset_field(state, -3, method)?; + } + ffi::lua_pop(state, 1); + + protect_lua!(state, 1, 0, fn(state) { + let destructed_mt_key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, destructed_mt_key); + })?; + + // Create error print buffer + init_internal_metatable::(state, None)?; + push_internal_userdata(state, String::new(), true)?; + protect_lua!(state, 1, 0, fn(state) { + let err_buf_key = &ERROR_PRINT_BUFFER_KEY as *const u8 as *const c_void; + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, err_buf_key); + })?; + + Ok(()) +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 00000000..0741f12c --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,362 @@ +use std::borrow::Cow; +use std::ffi::CStr; +use std::os::raw::{c_char, c_int, c_void}; +use std::{ptr, slice, str}; + +use crate::error::{Error, Result}; + +pub(crate) use error::{ + WrappedFailure, error_traceback, error_traceback_thread, init_error_registry, pop_error, + protect_lua_call, protect_lua_closure, +}; +pub(crate) use path::parse_path as parse_lookup_path; +pub(crate) use short_names::short_type_name; +pub(crate) use types::TypeKey; +pub(crate) use userdata::{ + DESTRUCTED_USERDATA_METATABLE, get_destructed_userdata_metatable, get_internal_metatable, + get_internal_userdata, get_userdata, init_internal_metatable, push_internal_userdata, push_userdata, + take_userdata, +}; + +#[cfg(not(feature = "luau"))] +pub(crate) use userdata::push_uninit_userdata; + +// Checks that Lua has enough free stack space for future stack operations. On failure, this will +// panic with an internal error message. +#[inline] +pub(crate) unsafe fn assert_stack(state: *mut ffi::lua_State, amount: c_int) { + // TODO: This should only be triggered when there is a logic error in `mlua`. In the future, + // when there is a way to be confident about stack safety and test it, this could be enabled + // only when `cfg!(debug_assertions)` is true. + mlua_assert!(ffi::lua_checkstack(state, amount) != 0, "out of stack space"); +} + +// Checks that Lua has enough free stack space and returns `Error::StackError` on failure. +#[inline] +pub(crate) unsafe fn check_stack(state: *mut ffi::lua_State, amount: c_int) -> Result<()> { + if ffi::lua_checkstack(state, amount) == 0 { + Err(Error::StackError) + } else { + Ok(()) + } +} + +pub(crate) struct StackGuard { + state: *mut ffi::lua_State, + top: c_int, +} + +impl StackGuard { + // Creates a StackGuard instance with record of the stack size, and on Drop will check the + // stack size and drop any extra elements. If the stack size at the end is *smaller* than at + // the beginning, this is considered a fatal logic error and will result in a panic. + #[inline] + pub(crate) unsafe fn new(state: *mut ffi::lua_State) -> StackGuard { + StackGuard { + state, + top: ffi::lua_gettop(state), + } + } + + // Same as `new()`, but allows specifying the expected stack size at the end of the scope. + #[inline] + pub(crate) fn with_top(state: *mut ffi::lua_State, top: c_int) -> StackGuard { + StackGuard { state, top } + } + + #[inline] + pub(crate) fn keep(&mut self, n: c_int) { + self.top += n; + } +} + +impl Drop for StackGuard { + #[track_caller] + fn drop(&mut self) { + unsafe { + let top = ffi::lua_gettop(self.state); + if top < self.top { + mlua_panic!("{} too many stack values popped", self.top - top) + } + if top > self.top { + ffi::lua_settop(self.state, self.top); + } + } + } +} + +// Uses 3 (or 1 if unprotected) stack spaces, does not call checkstack. +#[inline(always)] +pub(crate) unsafe fn push_string(state: *mut ffi::lua_State, s: &[u8], protect: bool) -> Result<()> { + // Always use protected mode if the string is too long + if protect || s.len() >= const { 1 << 30 } { + protect_lua!(state, 0, 1, |state| { + ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()); + }) + } else { + ffi::lua_pushlstring(state, s.as_ptr() as *const c_char, s.len()); + Ok(()) + } +} + +// Uses 3 (or 1 if unprotected) stack spaces, does not call checkstack. +#[cfg(feature = "lua55")] +pub(crate) unsafe fn push_external_string( + state: *mut ffi::lua_State, + mut bytes: Vec, + protect: bool, +) -> Result<()> { + bytes.push(0); + let s_len = bytes.len() - 1; // exclude null terminator + let s_ptr = bytes.as_ptr() as *const c_char; + let bytes_ud = Box::into_raw(Box::new(bytes)); + + unsafe extern "C" fn dealloc(ud: *mut c_void, _: *mut c_void, _: usize, _: usize) -> *mut c_void { + drop(Box::from_raw(ud as *mut Vec)); + ptr::null_mut() + } + + if protect { + let res = protect_lua!(state, 0, 1, move |state| { + ffi::lua_pushexternalstring(state, s_ptr, s_len, Some(dealloc), bytes_ud as *mut _); + }); + if res.is_err() { + // Deallocate on error + drop(Box::from_raw(bytes_ud)); + return res; + } + } else { + ffi::lua_pushexternalstring(state, s_ptr, s_len, Some(dealloc), bytes_ud as *mut _); + } + Ok(()) +} + +// Uses 3 stack spaces (when protect), does not call checkstack. +#[cfg(feature = "luau")] +#[inline(always)] +pub(crate) unsafe fn push_buffer(state: *mut ffi::lua_State, size: usize, protect: bool) -> Result<*mut u8> { + let data = if protect || size > const { 1024 * 1024 * 1024 } { + protect_lua!(state, 0, 1, |state| ffi::lua_newbuffer(state, size))? + } else { + ffi::lua_newbuffer(state, size) + }; + Ok(data as *mut u8) +} + +// Uses 3 stack spaces, does not call checkstack. +#[inline] +pub(crate) unsafe fn push_table( + state: *mut ffi::lua_State, + narr: usize, + nrec: usize, + protect: bool, +) -> Result<()> { + let narr: c_int = narr.try_into().unwrap_or(c_int::MAX); + let nrec: c_int = nrec.try_into().unwrap_or(c_int::MAX); + if protect || narr >= const { 1 << 26 } || nrec >= const { 1 << 26 } { + protect_lua!(state, 0, 1, |state| ffi::lua_createtable(state, narr, nrec)) + } else { + ffi::lua_createtable(state, narr, nrec); + Ok(()) + } +} + +// Uses 4 stack spaces, does not call checkstack. +pub(crate) unsafe fn rawget_field(state: *mut ffi::lua_State, table: c_int, field: &str) -> Result { + ffi::lua_pushvalue(state, table); + protect_lua!(state, 1, 1, |state| { + ffi::lua_pushlstring(state, field.as_ptr() as *const c_char, field.len()); + ffi::lua_rawget(state, -2) + }) +} + +// Uses 4 stack spaces, does not call checkstack. +pub(crate) unsafe fn rawset_field(state: *mut ffi::lua_State, table: c_int, field: &str) -> Result<()> { + ffi::lua_pushvalue(state, table); + protect_lua!(state, 2, 0, |state| { + ffi::lua_pushlstring(state, field.as_ptr() as *const c_char, field.len()); + ffi::lua_rotate(state, -3, 2); + ffi::lua_rawset(state, -3); + }) +} + +// A variant of `pcall` that does not allow Lua to catch Rust panics from `callback_error`. +pub(crate) unsafe extern "C-unwind" fn safe_pcall(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_checkstack(state, 2, ptr::null()); + + let top = ffi::lua_gettop(state); + if top == 0 { + ffi::lua_pushstring(state, cstr!("not enough arguments to pcall")); + ffi::lua_error(state); + } + + if ffi::lua_pcall(state, top - 1, ffi::LUA_MULTRET, 0) == ffi::LUA_OK { + ffi::lua_pushboolean(state, 1); + ffi::lua_insert(state, 1); + ffi::lua_gettop(state) + } else { + let wf_ud = get_internal_userdata::(state, -1, ptr::null()); + if let Some(WrappedFailure::Panic(_)) = wf_ud.as_ref() { + ffi::lua_error(state); + } + ffi::lua_pushboolean(state, 0); + ffi::lua_insert(state, -2); + 2 + } +} + +// A variant of `xpcall` that does not allow Lua to catch Rust panics from `callback_error`. +pub(crate) unsafe extern "C-unwind" fn safe_xpcall(state: *mut ffi::lua_State) -> c_int { + unsafe extern "C-unwind" fn xpcall_msgh(state: *mut ffi::lua_State) -> c_int { + ffi::luaL_checkstack(state, 2, ptr::null()); + + let wf_ud = get_internal_userdata::(state, -1, ptr::null()); + if let Some(WrappedFailure::Panic(_)) = wf_ud.as_ref() { + 1 + } else { + ffi::lua_pushvalue(state, ffi::lua_upvalueindex(1)); + ffi::lua_insert(state, 1); + ffi::lua_call(state, ffi::lua_gettop(state) - 1, ffi::LUA_MULTRET); + ffi::lua_gettop(state) + } + } + + ffi::luaL_checkstack(state, 2, ptr::null()); + + let top = ffi::lua_gettop(state); + if top < 2 { + ffi::lua_pushstring(state, cstr!("not enough arguments to xpcall")); + ffi::lua_error(state); + } + + ffi::lua_pushvalue(state, 2); + ffi::lua_pushcclosure(state, xpcall_msgh, 1); + ffi::lua_copy(state, 1, 2); + ffi::lua_replace(state, 1); + + if ffi::lua_pcall(state, ffi::lua_gettop(state) - 2, ffi::LUA_MULTRET, 1) == ffi::LUA_OK { + ffi::lua_pushboolean(state, 1); + ffi::lua_insert(state, 2); + ffi::lua_gettop(state) - 1 + } else { + let wf_ud = get_internal_userdata::(state, -1, ptr::null()); + if let Some(WrappedFailure::Panic(_)) = wf_ud.as_ref() { + ffi::lua_error(state); + } + ffi::lua_pushboolean(state, 0); + ffi::lua_insert(state, -2); + 2 + } +} + +// Returns Lua main thread for Lua >= 5.2 or checks that the passed thread is main for Lua 5.1. +// Does not call lua_checkstack, uses 1 stack space. +pub(crate) unsafe fn get_main_state(state: *mut ffi::lua_State) -> Option<*mut ffi::lua_State> { + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52"))] + { + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_MAINTHREAD); + let main_state = ffi::lua_tothread(state, -1); + ffi::lua_pop(state, 1); + Some(main_state) + } + #[cfg(any(feature = "lua51", feature = "luajit"))] + { + // Check the current state first + let is_main_state = ffi::lua_pushthread(state) == 1; + ffi::lua_pop(state, 1); + if is_main_state { Some(state) } else { None } + } + #[cfg(feature = "luau")] + Some(ffi::lua_mainthread(state)) +} + +// Converts the given lua value to a string in a reasonable format without causing a Lua error or +// panicking. +pub(crate) unsafe fn to_string(state: *mut ffi::lua_State, index: c_int) -> String { + match ffi::lua_type(state, index) { + ffi::LUA_TNONE => "".to_string(), + ffi::LUA_TNIL => "".to_string(), + ffi::LUA_TBOOLEAN => (ffi::lua_toboolean(state, index) != 1).to_string(), + ffi::LUA_TLIGHTUSERDATA => { + format!("", ffi::lua_topointer(state, index)) + } + ffi::LUA_TNUMBER => { + let mut isint = 0; + let i = ffi::lua_tointegerx(state, -1, &mut isint); + if isint == 0 { + ffi::lua_tonumber(state, index).to_string() + } else { + i.to_string() + } + } + #[cfg(feature = "luau")] + ffi::LUA_TVECTOR => { + let v = ffi::lua_tovector(state, index); + mlua_debug_assert!(!v.is_null(), "vector is null"); + let (x, y, z) = (*v, *v.add(1), *v.add(2)); + #[cfg(not(feature = "luau-vector4"))] + return format!("vector({x}, {y}, {z})"); + #[cfg(feature = "luau-vector4")] + return format!("vector({x}, {y}, {z}, {w})", w = *v.add(3)); + } + ffi::LUA_TSTRING => { + let mut size = 0; + // This will not trigger a 'm' error, because the reference is guaranteed to be of + // string type + let data = ffi::lua_tolstring(state, index, &mut size); + String::from_utf8_lossy(slice::from_raw_parts(data as *const u8, size)).into_owned() + } + ffi::LUA_TTABLE => format!("
", ffi::lua_topointer(state, index)), + ffi::LUA_TFUNCTION => format!("", ffi::lua_topointer(state, index)), + ffi::LUA_TUSERDATA => format!("", ffi::lua_topointer(state, index)), + ffi::LUA_TTHREAD => format!("", ffi::lua_topointer(state, index)), + #[cfg(feature = "luau")] + ffi::LUA_TBUFFER => format!("", ffi::lua_topointer(state, index)), + type_id => { + let type_name = CStr::from_ptr(ffi::lua_typename(state, type_id)).to_string_lossy(); + format!("<{type_name} {:?}>", ffi::lua_topointer(state, index)) + } + } +} + +#[inline(always)] +pub(crate) unsafe fn get_metatable_ptr(state: *mut ffi::lua_State, index: c_int) -> *const c_void { + #[cfg(feature = "luau")] + return ffi::lua_getmetatablepointer(state, index); + + #[cfg(not(feature = "luau"))] + if ffi::lua_getmetatable(state, index) == 0 { + ptr::null() + } else { + let p = ffi::lua_topointer(state, -1); + ffi::lua_pop(state, 1); + p + } +} + +pub(crate) unsafe fn ptr_to_str<'a>(input: *const c_char) -> Option<&'a str> { + if input.is_null() { + return None; + } + str::from_utf8(CStr::from_ptr(input).to_bytes()).ok() +} + +pub(crate) unsafe fn ptr_to_lossy_str<'a>(input: *const c_char) -> Option> { + if input.is_null() { + return None; + } + Some(String::from_utf8_lossy(CStr::from_ptr(input).to_bytes())) +} + +pub(crate) fn linenumber_to_usize(n: c_int) -> Option { + match n { + n if n < 0 => None, + n => Some(n as usize), + } +} + +mod error; +mod path; +mod short_names; +mod types; +mod userdata; diff --git a/src/util/path.rs b/src/util/path.rs new file mode 100644 index 00000000..b1381e12 --- /dev/null +++ b/src/util/path.rs @@ -0,0 +1,255 @@ +use std::borrow::Cow; +use std::fmt; +use std::iter::Peekable; +use std::str::CharIndices; + +use crate::error::{Error, Result}; +use crate::state::Lua; +use crate::traits::IntoLua; +use crate::types::Integer; +use crate::value::Value; + +#[derive(Debug)] +pub(crate) enum PathKey<'a> { + Str(Cow<'a, str>), + Int(Integer), +} + +impl fmt::Display for PathKey<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + PathKey::Str(s) => write!(f, "{}", s), + PathKey::Int(i) => write!(f, "{}", i), + } + } +} + +impl IntoLua for PathKey<'_> { + fn into_lua(self, lua: &Lua) -> Result { + match self { + PathKey::Str(s) => Ok(Value::String(lua.create_string(s.as_ref())?)), + PathKey::Int(i) => Ok(Value::Integer(i)), + } + } +} + +// Parses a path like `a.b[3]?.c["d"]` into segments of `(key, safe_nil)`. +pub(crate) fn parse_path<'a>(path: &'a str) -> Result, bool)>> { + fn read_ident<'a>(path: &'a str, chars: &mut Peekable>) -> (Cow<'a, str>, bool) { + let mut safe_nil = false; + let start = chars.peek().map(|&(i, _)| i).unwrap_or(path.len()); + let mut end = start; + while let Some(&(pos, c)) = chars.peek() { + if c == '.' || c == '?' || c.is_ascii_whitespace() || c == '[' { + if c == '?' { + safe_nil = true; + chars.next(); // consume '?' + } + break; + } + end = pos + c.len_utf8(); + chars.next(); + } + (Cow::Borrowed(&path[start..end]), safe_nil) + } + + let mut segments = Vec::new(); + let mut chars = path.char_indices().peekable(); + while let Some(&(pos, next)) = chars.peek() { + match next { + '.' => { + // Dot notation: identifier + chars.next(); + let (key, safe_nil) = read_ident(path, &mut chars); + if key.is_empty() { + return Err(Error::runtime(format!("empty key in path at position {pos}"))); + } + segments.push((PathKey::Str(key), safe_nil)); + } + '[' => { + // Bracket notation: either integer or quoted string + chars.next(); + let key = match chars.peek() { + Some(&(pos, c @ '0'..='9' | c @ '-')) => { + // Integer key + let negative = c == '-'; + if negative { + chars.next(); // consume '-' + } + let mut num: Option = None; + while let Some(&(_, c @ '0'..='9')) = chars.peek() { + let new_num = num + .unwrap_or(0) + .checked_mul(10) + .and_then(|n| n.checked_add((c as u8 - b'0') as Integer)) + .ok_or_else(|| { + Error::runtime(format!("integer overflow in path at position {pos}")) + })?; + num = Some(new_num); + chars.next(); // consume digit + } + match num { + Some(n) if negative => PathKey::Int(-n), + Some(n) => PathKey::Int(n), + None => { + let err = format!("invalid integer in path at position {pos}"); + return Err(Error::runtime(err)); + } + } + } + Some((_, '\'' | '"')) => { + // Quoted string + PathKey::Str(unquote_string(path, &mut chars)?) + } + Some((_, ']')) => { + return Err(Error::runtime(format!("empty key in path at position {pos}"))); + } + Some((pos, c)) => { + let err = format!("unexpected character '{c}' in path at position {pos}"); + return Err(Error::runtime(err)); + } + None => { + return Err(Error::runtime("unexpected end of path")); + } + }; + // Expect closing bracket + let mut safe_nil = false; + match chars.next() { + Some((_, ']')) => { + // Check for optional safe-nil operator + if let Some(&(_, '?')) = chars.peek() { + safe_nil = true; + chars.next(); // consume '?' + } + } + Some((pos, c)) => { + let err = format!("expected ']' in path at position {pos}, found '{c}'"); + return Err(Error::runtime(err)); + } + None => { + return Err(Error::runtime("unexpected end of path")); + } + } + segments.push((key, safe_nil)); + } + c if c.is_ascii_whitespace() => { + chars.next(); // Skip whitespace + } + _ if segments.is_empty() => { + // First segment without dot/bracket notation + let (key_cow, safe_nil) = read_ident(path, &mut chars); + if key_cow.is_empty() { + return Err(Error::runtime(format!("empty key in path at position {pos}"))); + } + segments.push((PathKey::Str(key_cow), safe_nil)); + } + c => { + let err = format!("unexpected character '{c}' in path at position {pos}"); + return Err(Error::runtime(err)); + } + } + } + Ok(segments) +} + +fn unquote_string<'a>(path: &'a str, chars: &mut Peekable>) -> Result> { + let (start_pos, first_quote) = chars.next().unwrap(); + let mut result = String::new(); + loop { + match chars.next() { + Some((pos, '\\')) => { + if result.is_empty() { + // First escape found, copy everything up to this point + result.push_str(&path[start_pos + 1..pos]); + } + match chars.next() { + Some((_, '\\')) => result.push('\\'), + Some((_, '"')) => result.push('"'), + Some((_, '\'')) => result.push('\''), + Some((_, other)) => { + result.push('\\'); + result.push(other); + } + None => continue, // will be handled by outer loop + } + } + Some((pos, c)) if c == first_quote => { + if !result.is_empty() { + return Ok(Cow::Owned(result)); + } + // No escapes, return borrowed slice + return Ok(Cow::Borrowed(&path[start_pos + 1..pos])); + } + Some((_, c)) => { + if !result.is_empty() { + result.push(c); + } + // If no escapes yet, continue tracking for potential borrowed slice + } + None => { + let err = format!("unexpected end of string at position {start_pos}"); + return Err(Error::runtime(err)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::{PathKey, parse_path}; + + #[test] + fn test_parse_path() { + // Test valid paths + let path = parse_path("a.b[3]?.c['d']").unwrap(); + assert_eq!(path.len(), 5); + assert!(matches!(path[0], (PathKey::Str(ref s), false) if s == "a")); + assert!(matches!(path[1], (PathKey::Str(ref s), false) if s == "b")); + assert!(matches!(path[2], (PathKey::Int(3), true))); + assert!(matches!(path[3], (PathKey::Str(ref s), false) if s == "c")); + assert!(matches!(path[4], (PathKey::Str(ref s), false) if s == "d")); + + // Test empty path + let path = parse_path("").unwrap(); + assert_eq!(path.len(), 0); + let path = parse_path(" ").unwrap(); + assert_eq!(path.len(), 0); + + // Test invalid dot syntax + let err = parse_path("a..b").unwrap_err().to_string(); + assert_eq!(err, "runtime error: empty key in path at position 1"); + let err = parse_path("a.b.").unwrap_err().to_string(); + assert_eq!(err, "runtime error: empty key in path at position 3"); + + // Test invalid bracket syntax + let err = parse_path("a[unclosed").unwrap_err().to_string(); + assert_eq!( + err, + "runtime error: unexpected character 'u' in path at position 2" + ); + let err = parse_path("a[]").unwrap_err().to_string(); + assert_eq!(err, "runtime error: empty key in path at position 1"); + let err = parse_path(r#"a["unclosed"#).unwrap_err().to_string(); + assert_eq!(err, "runtime error: unexpected end of string at position 2"); + let err = parse_path(r#"a["#).unwrap_err().to_string(); + assert_eq!(err, "runtime error: unexpected end of path"); + let err = parse_path(r#"a[123"#).unwrap_err().to_string(); + assert_eq!(err, "runtime error: unexpected end of path"); + let err = parse_path(r#"a['bla'123"#).unwrap_err().to_string(); + assert_eq!( + err, + "runtime error: expected ']' in path at position 7, found '1'" + ); + let err = parse_path(r#"a["bla"]x"#).unwrap_err().to_string(); + assert_eq!( + err, + "runtime error: unexpected character 'x' in path at position 8" + ); + + // Test bad integers + let err = parse_path("a[99999999999999999999]").unwrap_err().to_string(); + assert_eq!(err, "runtime error: integer overflow in path at position 2"); + let err = parse_path("a[-]").unwrap_err().to_string(); + assert_eq!(err, "runtime error: invalid integer in path at position 2"); + } +} diff --git a/src/util/short_names.rs b/src/util/short_names.rs new file mode 100644 index 00000000..0fae0246 --- /dev/null +++ b/src/util/short_names.rs @@ -0,0 +1,87 @@ +//! Inspired by bevy's [disqualified] +//! +//! [disqualified]: https://github.com/bevyengine/disqualified/blob/main/src/short_name.rs + +use std::any::type_name; + +/// Returns a short version of a type name `T` without all module paths. +/// +/// The short name of a type is its full name as returned by +/// [`std::any::type_name`], but with the prefix of all paths removed. For +/// example, the short name of `alloc::vec::Vec>` +/// would be `Vec>`. +pub(crate) fn short_type_name() -> String { + let full_name = type_name::(); + + // Generics result in nested paths within <..> blocks. + // Consider "core::option::Option". + // To tackle this, we parse the string from left to right, collapsing as we go. + let mut index: usize = 0; + let end_of_string = full_name.len(); + let mut parsed_name = String::new(); + + while index < end_of_string { + let rest_of_string = full_name.get(index..end_of_string).unwrap_or_default(); + + // Collapse everything up to the next special character, then skip over it + if let Some(special_character_index) = + rest_of_string.find(|c: char| [' ', '<', '>', '(', ')', '[', ']', ',', ';'].contains(&c)) + { + let segment_to_collapse = rest_of_string.get(0..special_character_index).unwrap_or_default(); + parsed_name += collapse_type_name(segment_to_collapse); + // Insert the special character + let special_character = &rest_of_string[special_character_index..=special_character_index]; + parsed_name += special_character; + + // Remove lifetimes like <'_> or <'_, '_, ...> + if parsed_name.ends_with("<'_>") || parsed_name.ends_with("<'_, ") { + _ = parsed_name.split_off(parsed_name.len() - 4); + } + + match special_character { + ">" | ")" | "]" if rest_of_string[special_character_index + 1..].starts_with("::") => { + parsed_name += "::"; + // Move the index past the "::" + index += special_character_index + 3; + } + // Move the index just past the special character + _ => index += special_character_index + 1, + } + } else { + // If there are no special characters left, we're done! + parsed_name += collapse_type_name(rest_of_string); + index = end_of_string; + } + } + parsed_name +} + +#[inline(always)] +fn collapse_type_name(segment: &str) -> &str { + segment.rsplit("::").next().unwrap() +} + +#[cfg(test)] +mod tests { + use super::short_type_name; + use std::collections::HashMap; + use std::marker::PhantomData; + + struct MyData<'a, 'b>(PhantomData<&'a &'b ()>); + struct MyDataT<'a, T>(PhantomData<&'a T>); + + #[test] + fn tests() { + assert_eq!(short_type_name::(), "String"); + assert_eq!(short_type_name::>(), "Option"); + assert_eq!(short_type_name::<(String, &str)>(), "(String, &str)"); + assert_eq!(short_type_name::<[i32; 3]>(), "[i32; 3]"); + assert_eq!( + short_type_name::>>(), + "HashMap>" + ); + assert_eq!(short_type_name:: i32>(), "dyn Fn(i32) -> i32"); + assert_eq!(short_type_name::>(), "MyDataT<&str>"); + assert_eq!(short_type_name::<(&MyData, [MyData])>(), "(MyData, [MyData])"); + } +} diff --git a/src/util/types.rs b/src/util/types.rs new file mode 100644 index 00000000..8bc9d8b2 --- /dev/null +++ b/src/util/types.rs @@ -0,0 +1,80 @@ +use std::any::Any; +use std::os::raw::c_void; + +use crate::types::{Callback, CallbackUpvalue}; + +#[cfg(feature = "async")] +use crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue}; + +pub(crate) trait TypeKey: Any { + fn type_key() -> *const c_void; +} + +impl TypeKey for String { + #[inline(always)] + fn type_key() -> *const c_void { + static STRING_TYPE_KEY: u8 = 0; + &STRING_TYPE_KEY as *const u8 as *const c_void + } +} + +impl TypeKey for Callback { + #[inline(always)] + fn type_key() -> *const c_void { + static CALLBACK_TYPE_KEY: u8 = 0; + &CALLBACK_TYPE_KEY as *const u8 as *const c_void + } +} + +impl TypeKey for CallbackUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static CALLBACK_UPVALUE_TYPE_KEY: u8 = 0; + &CALLBACK_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + +#[cfg(not(feature = "luau"))] +impl TypeKey for crate::types::HookCallback { + #[inline(always)] + fn type_key() -> *const c_void { + static HOOK_CALLBACK_TYPE_KEY: u8 = 0; + &HOOK_CALLBACK_TYPE_KEY as *const u8 as *const c_void + } +} + +#[cfg(feature = "async")] +impl TypeKey for AsyncCallback { + #[inline(always)] + fn type_key() -> *const c_void { + static ASYNC_CALLBACK_TYPE_KEY: u8 = 0; + &ASYNC_CALLBACK_TYPE_KEY as *const u8 as *const c_void + } +} + +#[cfg(feature = "async")] +impl TypeKey for AsyncCallbackUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static ASYNC_CALLBACK_UPVALUE_TYPE_KEY: u8 = 0; + &ASYNC_CALLBACK_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + +#[cfg(feature = "async")] +impl TypeKey for AsyncPollUpvalue { + #[inline(always)] + fn type_key() -> *const c_void { + static ASYNC_POLL_UPVALUE_TYPE_KEY: u8 = 0; + &ASYNC_POLL_UPVALUE_TYPE_KEY as *const u8 as *const c_void + } +} + +#[cfg(feature = "async")] +impl TypeKey for Option { + #[inline(always)] + fn type_key() -> *const c_void { + static WAKER_TYPE_KEY: u8 = 0; + &WAKER_TYPE_KEY as *const u8 as *const c_void + } +} diff --git a/src/util/userdata.rs b/src/util/userdata.rs new file mode 100644 index 00000000..76a9507c --- /dev/null +++ b/src/util/userdata.rs @@ -0,0 +1,173 @@ +use std::os::raw::{c_int, c_void}; +use std::{mem, ptr}; + +use crate::error::Result; +use crate::userdata::collect_userdata; +use crate::util::{TypeKey, check_stack, get_metatable_ptr, push_table, rawset_field}; + +// Pushes the userdata and attaches a metatable with __gc method. +// Internally uses 3 stack spaces, does not call checkstack. +pub(crate) unsafe fn push_internal_userdata( + state: *mut ffi::lua_State, + t: T, + protect: bool, +) -> Result<*mut T> { + #[cfg(not(feature = "luau"))] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| { + let ud_ptr = ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T; + ptr::write(ud_ptr, t); + ud_ptr + })? + } else { + let ud_ptr = ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T; + ptr::write(ud_ptr, t); + ud_ptr + }; + + #[cfg(feature = "luau")] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| ffi::lua_newuserdata_t::(state, t))? + } else { + ffi::lua_newuserdata_t::(state, t) + }; + + get_internal_metatable::(state); + ffi::lua_setmetatable(state, -2); + Ok(ud_ptr) +} + +#[track_caller] +pub(crate) unsafe fn get_internal_metatable(state: *mut ffi::lua_State) { + ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, T::type_key()); + debug_assert!(ffi::lua_isnil(state, -1) == 0, "internal metatable not found"); +} + +// Initialize the internal metatable for a type T (with __gc method). +// Uses 6 stack spaces and calls checkstack. +pub(crate) unsafe fn init_internal_metatable( + state: *mut ffi::lua_State, + customize_fn: Option, +) -> Result<()> { + check_stack(state, 6)?; + + push_table(state, 0, 3, true)?; + + #[cfg(not(feature = "luau"))] + { + ffi::lua_pushcfunction(state, collect_userdata::); + rawset_field(state, -2, "__gc")?; + } + + ffi::lua_pushboolean(state, 0); + rawset_field(state, -2, "__metatable")?; + + protect_lua!(state, 1, 0, |state| { + if let Some(f) = customize_fn { + f(state); + } + + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, T::type_key()); + })?; + + Ok(()) +} + +// Uses up to 1 stack space, does not call `checkstack` +pub(crate) unsafe fn get_internal_userdata( + state: *mut ffi::lua_State, + index: c_int, + mut type_mt_ptr: *const c_void, +) -> *mut T { + let ud = ffi::lua_touserdata(state, index) as *mut T; + if ud.is_null() { + return ptr::null_mut(); + } + let mt_ptr = get_metatable_ptr(state, index); + if type_mt_ptr.is_null() { + get_internal_metatable::(state); + type_mt_ptr = ffi::lua_topointer(state, -1); + ffi::lua_pop(state, 1); + } + if mt_ptr != type_mt_ptr { + return ptr::null_mut(); + } + ud +} + +// Internally uses 3 stack spaces, does not call checkstack. +#[inline] +#[cfg(not(feature = "luau"))] +pub(crate) unsafe fn push_uninit_userdata(state: *mut ffi::lua_State, protect: bool) -> Result<*mut T> { + if protect { + protect_lua!(state, 0, 1, |state| { + ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T + }) + } else { + Ok(ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T) + } +} + +// Internally uses 3 stack spaces, does not call checkstack. +#[inline] +pub(crate) unsafe fn push_userdata(state: *mut ffi::lua_State, t: T, protect: bool) -> Result<*mut T> { + let size = const { mem::size_of::() }; + + #[cfg(not(feature = "luau"))] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| ffi::lua_newuserdata(state, size))? + } else { + ffi::lua_newuserdata(state, size) + } as *mut T; + + #[cfg(feature = "luau")] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, |state| { + ffi::lua_newuserdatadtor(state, size, collect_userdata::) + })? + } else { + ffi::lua_newuserdatadtor(state, size, collect_userdata::) + } as *mut T; + + ptr::write(ud_ptr, t); + Ok(ud_ptr) +} + +#[inline] +#[track_caller] +pub(crate) unsafe fn get_userdata(state: *mut ffi::lua_State, index: c_int) -> *mut T { + let ud = ffi::lua_touserdata(state, index) as *mut T; + mlua_debug_assert!(!ud.is_null(), "userdata pointer is null"); + ud +} + +/// Unwraps `T` from the Lua userdata and invalidating it by setting the special "destructed" +/// metatable. +/// +/// This method does not check that userdata is of type `T` and was not previously invalidated. +/// +/// Uses 1 extra stack space, does not call checkstack. +pub(crate) unsafe fn take_userdata(state: *mut ffi::lua_State, idx: c_int) -> T { + #[rustfmt::skip] + let idx = if idx < 0 { ffi::lua_absindex(state, idx) } else { idx }; + + // Update the metatable of this userdata to a special one with no `__gc` method and with + // metamethods that trigger an error on access. + // We do this so that it will not be double dropped or used after being dropped. + get_destructed_userdata_metatable(state); + ffi::lua_setmetatable(state, idx); + let ud = get_userdata::(state, idx); + + // Update userdata tag to disable destructor and mark as destructed + #[cfg(feature = "luau")] + ffi::lua_setuserdatatag(state, idx, 1); + + ptr::read(ud) +} + +pub(crate) unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_State) { + let key = &DESTRUCTED_USERDATA_METATABLE as *const u8 as *const c_void; + ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, key); +} + +pub(crate) static DESTRUCTED_USERDATA_METATABLE: u8 = 0; diff --git a/src/value.rs b/src/value.rs index 96f36fcb..b4891e73 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,30 +1,36 @@ -use std::iter::{self, FromIterator}; +use std::cmp::Ordering; +use std::collections::HashSet; use std::os::raw::c_void; -use std::{ptr, slice, str, vec}; +use std::{fmt, ptr, str}; -#[cfg(feature = "serialize")] -use { - serde::ser::{self, Serialize, Serializer}, - std::convert::TryInto, - std::result::Result as StdResult, -}; +use num_traits::FromPrimitive; use crate::error::{Error, Result}; -use crate::ffi; use crate::function::Function; -use crate::lua::Lua; -use crate::string::String; +use crate::string::{BorrowedStr, LuaString}; use crate::table::Table; use crate::thread::Thread; -use crate::types::{Integer, LightUserData, Number}; +use crate::types::{Integer, LightUserData, Number, ValueRef}; use crate::userdata::AnyUserData; +use crate::util::{StackGuard, check_stack}; + +#[cfg(feature = "serde")] +use { + crate::table::SerializableTable, + rustc_hash::FxHashSet, + serde::ser::{self, Serialize, Serializer}, + std::{cell::RefCell, rc::Rc, result::Result as StdResult}, +}; -/// A dynamically typed Lua value. The `String`, `Table`, `Function`, `Thread`, and `UserData` -/// variants contain handle types into the internal Lua state. It is a logic error to mix handle -/// types between separate `Lua` instances, and doing so will result in a panic. -#[derive(Debug, Clone)] -pub enum Value<'lua> { +/// A dynamically typed Lua value. +/// +/// The non-primitive variants (eg. string/table/function/thread/userdata) contain handle types +/// into the internal Lua state. It is a logic error to mix handle types between separate +/// `Lua` instances, and doing so will result in a panic. +#[derive(Clone, Default)] +pub enum Value { /// The Lua value `nil`. + #[default] Nil, /// The Lua value `true` or `false`. Boolean(bool), @@ -39,28 +45,41 @@ pub enum Value<'lua> { /// A Luau vector. #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - Vector(f32, f32, f32), + Vector(crate::Vector), /// An interned string, managed by Lua. /// /// Unlike Rust strings, Lua strings may not be valid UTF-8. - String(String<'lua>), + String(LuaString), /// Reference to a Lua table. - Table(Table<'lua>), + Table(Table), /// Reference to a Lua function (or closure). - Function(Function<'lua>), + Function(Function), /// Reference to a Lua thread (or coroutine). - Thread(Thread<'lua>), + Thread(Thread), /// Reference to a userdata object that holds a custom type which implements `UserData`. + /// /// Special builtin userdata types will be represented as other `Value` variants. - UserData(AnyUserData<'lua>), + UserData(AnyUserData), + /// A Luau buffer. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + Buffer(crate::Buffer), /// `Error` is a special builtin userdata type. When received from Lua it is implicitly cloned. - Error(Error), + Error(Box), + /// Any other value not known to mlua (eg. LuaJIT CData). + Other(#[doc(hidden)] ValueRef), } pub use self::Value::Nil; -impl<'lua> Value<'lua> { - pub const fn type_name(&self) -> &'static str { +impl Value { + /// A special value (lightuserdata) to represent null value. + /// + /// It can be used in Lua tables without downsides of `nil`. + pub const NULL: Value = Value::LightUserData(LightUserData(ptr::null_mut())); + + /// Returns type name of this value. + pub fn type_name(&self) -> &'static str { match *self { Value::Nil => "nil", Value::Boolean(_) => "boolean", @@ -68,266 +87,677 @@ impl<'lua> Value<'lua> { Value::Integer(_) => "integer", Value::Number(_) => "number", #[cfg(feature = "luau")] - Value::Vector(_, _, _) => "vector", + Value::Vector(_) => "vector", Value::String(_) => "string", Value::Table(_) => "table", Value::Function(_) => "function", Value::Thread(_) => "thread", Value::UserData(_) => "userdata", + #[cfg(feature = "luau")] + Value::Buffer(_) => "buffer", Value::Error(_) => "error", + Value::Other(_) => "other", } } /// Compares two values for equality. /// /// Equality comparisons do not convert strings to numbers or vice versa. - /// Tables, Functions, Threads, and Userdata are compared by reference: + /// Tables, functions, threads, and userdata are compared by reference: /// two objects are considered equal only if they are the same object. /// - /// If Tables or Userdata have `__eq` metamethod then mlua will try to invoke it. + /// If table or userdata have `__eq` metamethod then mlua will try to invoke it. /// The first value is checked first. If that value does not define a metamethod /// for `__eq`, then mlua will check the second value. /// Then mlua calls the metamethod with the two values as arguments, if found. - pub fn equals>(&self, other: T) -> Result { - match (self, other.as_ref()) { + pub fn equals(&self, other: &Self) -> Result { + match (self, other) { (Value::Table(a), Value::Table(b)) => a.equals(b), (Value::UserData(a), Value::UserData(b)) => a.equals(b), - _ => Ok(self == other.as_ref()), + (a, b) => Ok(a == b), } } /// Converts the value to a generic C pointer. /// - /// The value can be a userdata, a table, a thread, a string, or a function; otherwise it returns NULL. - /// Different objects will give different pointers. + /// The value can be a userdata, a table, a thread, a string, or a function; otherwise it + /// returns NULL. Different objects will give different pointers. /// There is no way to convert the pointer back to its original value. /// /// Typically this function is used only for hashing and debug information. + #[inline] pub fn to_pointer(&self) -> *const c_void { - unsafe { - match self { - Value::LightUserData(ud) => ud.0, - Value::String(String(v)) - | Value::Table(Table(v)) - | Value::Function(Function(v)) - | Value::Thread(Thread(v)) - | Value::UserData(AnyUserData(v)) => v - .lua - .ref_thread_exec(|refthr| ffi::lua_topointer(refthr, v.index)), - _ => ptr::null(), - } + match self { + Value::LightUserData(ud) => ud.0, + Value::Table(Table(vref)) + | Value::Function(Function(vref)) + | Value::Thread(Thread(vref, ..)) + | Value::UserData(AnyUserData(vref)) + | Value::Other(vref) => vref.to_pointer(), + Value::String(s) => s.to_pointer(), + #[cfg(feature = "luau")] + Value::Buffer(crate::Buffer(vref)) => vref.to_pointer(), + _ => ptr::null(), } } -} -impl<'lua> PartialEq for Value<'lua> { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Value::Nil, Value::Nil) => true, - (Value::Boolean(a), Value::Boolean(b)) => a == b, - (Value::LightUserData(a), Value::LightUserData(b)) => a == b, - (Value::Integer(a), Value::Integer(b)) => *a == *b, - (Value::Integer(a), Value::Number(b)) => *a as Number == *b, - (Value::Number(a), Value::Integer(b)) => *a == *b as Number, - (Value::Number(a), Value::Number(b)) => *a == *b, + /// Converts the value to a string. + /// + /// This might invoke the `__tostring` metamethod for non-primitive types (eg. tables, + /// functions). + pub fn to_string(&self) -> Result { + unsafe fn invoke_tostring(vref: &ValueRef) -> Result { + let lua = vref.lua.lock(); + let state = lua.state(); + let _guard = StackGuard::new(state); + check_stack(state, 3)?; + + lua.push_ref(vref); + protect_lua!(state, 1, 1, fn(state) { + ffi::luaL_tolstring(state, -1, ptr::null_mut()); + })?; + Ok(LuaString(lua.pop_ref()).to_str()?.to_string()) + } + + match self { + Value::Nil => Ok("nil".to_string()), + Value::Boolean(b) => Ok(b.to_string()), + Value::LightUserData(ud) if ud.0.is_null() => Ok("null".to_string()), + Value::LightUserData(ud) => Ok(format!("lightuserdata: {:p}", ud.0)), + Value::Integer(i) => Ok(i.to_string()), + Value::Number(n) => Ok(n.to_string()), #[cfg(feature = "luau")] - (Value::Vector(x1, y1, z1), Value::Vector(x2, y2, z2)) => (x1, y1, z1) == (x2, y2, z2), - (Value::String(a), Value::String(b)) => a == b, - (Value::Table(a), Value::Table(b)) => a == b, - (Value::Function(a), Value::Function(b)) => a == b, - (Value::Thread(a), Value::Thread(b)) => a == b, - (Value::UserData(a), Value::UserData(b)) => a == b, - _ => false, + Value::Vector(v) => Ok(v.to_string()), + Value::String(s) => Ok(s.to_str()?.to_string()), + Value::Table(Table(vref)) + | Value::Function(Function(vref)) + | Value::Thread(Thread(vref, ..)) + | Value::UserData(AnyUserData(vref)) + | Value::Other(vref) => unsafe { invoke_tostring(vref) }, + #[cfg(feature = "luau")] + Value::Buffer(crate::Buffer(vref)) => unsafe { invoke_tostring(vref) }, + Value::Error(err) => Ok(err.to_string()), } } -} -impl<'lua> AsRef> for Value<'lua> { + /// Returns `true` if the value is a [`Nil`]. #[inline] - fn as_ref(&self) -> &Self { - self + pub fn is_nil(&self) -> bool { + self == &Nil } -} -#[cfg(feature = "serialize")] -impl<'lua> Serialize for Value<'lua> { - fn serialize(&self, serializer: S) -> StdResult - where - S: Serializer, - { - match self { - Value::Nil => serializer.serialize_unit(), - Value::Boolean(b) => serializer.serialize_bool(*b), - #[allow(clippy::useless_conversion)] - Value::Integer(i) => serializer - .serialize_i64((*i).try_into().expect("cannot convert lua_Integer to i64")), - #[allow(clippy::useless_conversion)] - Value::Number(n) => serializer.serialize_f64(*n), - #[cfg(feature = "luau")] - Value::Vector(x, y, z) => (x, y, z).serialize(serializer), - Value::String(s) => s.serialize(serializer), - Value::Table(t) => t.serialize(serializer), - Value::UserData(ud) => ud.serialize(serializer), - Value::LightUserData(ud) if ud.0.is_null() => serializer.serialize_none(), - Value::Error(_) | Value::LightUserData(_) | Value::Function(_) | Value::Thread(_) => { - let msg = format!("cannot serialize <{}>", self.type_name()); - Err(ser::Error::custom(msg)) - } + /// Returns `true` if the value is a [`NULL`]. + /// + /// [`NULL`]: Value::NULL + #[inline] + pub fn is_null(&self) -> bool { + self == &Self::NULL + } + + /// Returns `true` if the value is a boolean. + #[inline] + pub fn is_boolean(&self) -> bool { + self.as_boolean().is_some() + } + + /// Cast the value to boolean. + /// + /// If the value is a Boolean, returns it or `None` otherwise. + #[inline] + pub fn as_boolean(&self) -> Option { + match *self { + Value::Boolean(b) => Some(b), + _ => None, } } -} -/// Trait for types convertible to `Value`. -pub trait ToLua<'lua> { - /// Performs the conversion. - fn to_lua(self, lua: &'lua Lua) -> Result>; -} + /// Returns `true` if the value is a [`LightUserData`]. + #[inline] + pub fn is_light_userdata(&self) -> bool { + self.as_light_userdata().is_some() + } -/// Trait for types convertible from `Value`. -pub trait FromLua<'lua>: Sized { - /// Performs the conversion. - fn from_lua(lua_value: Value<'lua>, lua: &'lua Lua) -> Result; -} + /// Cast the value to [`LightUserData`]. + /// + /// If the value is a [`LightUserData`], returns it or `None` otherwise. + #[inline] + pub fn as_light_userdata(&self) -> Option { + match *self { + Value::LightUserData(l) => Some(l), + _ => None, + } + } -/// Multiple Lua values used for both argument passing and also for multiple return values. -#[derive(Debug, Clone)] -pub struct MultiValue<'lua>(Vec>); + /// Returns `true` if the value is an [`Integer`]. + #[inline] + pub fn is_integer(&self) -> bool { + self.as_integer().is_some() + } -impl<'lua> MultiValue<'lua> { - /// Creates an empty `MultiValue` containing no values. + /// Cast the value to [`Integer`]. + /// + /// If the value is a Lua [`Integer`], returns it or `None` otherwise. #[inline] - pub fn new() -> MultiValue<'lua> { - MultiValue(Vec::new()) + pub fn as_integer(&self) -> Option { + match *self { + Value::Integer(i) => Some(i), + _ => None, + } } - /// Similar to `new` but can return previously used container with allocated capacity. + /// Cast the value to `i32`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `i32` or return `None` otherwise. #[inline] - pub(crate) fn new_or_cached(lua: &'lua Lua) -> MultiValue<'lua> { - lua.new_or_cached_multivalue() + pub fn as_i32(&self) -> Option { + #[allow(clippy::useless_conversion)] + self.as_integer().and_then(|i| i32::try_from(i).ok()) } -} -impl<'lua> Default for MultiValue<'lua> { + /// Cast the value to `u32`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `u32` or return `None` otherwise. #[inline] - fn default() -> MultiValue<'lua> { - MultiValue::new() + pub fn as_u32(&self) -> Option { + self.as_integer().and_then(|i| u32::try_from(i).ok()) } -} -impl<'lua> FromIterator> for MultiValue<'lua> { + /// Cast the value to `i64`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `i64` or return `None` otherwise. #[inline] - fn from_iter>>(iter: I) -> Self { - MultiValue::from_vec(Vec::from_iter(iter)) + pub fn as_i64(&self) -> Option { + #[cfg(target_pointer_width = "64")] + return self.as_integer(); + #[cfg(not(target_pointer_width = "64"))] + return self.as_integer().map(i64::from); } -} -impl<'lua> IntoIterator for MultiValue<'lua> { - type Item = Value<'lua>; - type IntoIter = iter::Rev>>; + /// Cast the value to `u64`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `u64` or return `None` otherwise. + #[inline] + pub fn as_u64(&self) -> Option { + self.as_integer().and_then(|i| u64::try_from(i).ok()) + } + /// Cast the value to `isize`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `isize` or return `None` otherwise. #[inline] - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter().rev() + pub fn as_isize(&self) -> Option { + self.as_integer().and_then(|i| isize::try_from(i).ok()) } -} -impl<'a, 'lua> IntoIterator for &'a MultiValue<'lua> { - type Item = &'a Value<'lua>; - type IntoIter = iter::Rev>>; + /// Cast the value to `usize`. + /// + /// If the value is a Lua [`Integer`], try to convert it to `usize` or return `None` otherwise. + #[inline] + pub fn as_usize(&self) -> Option { + self.as_integer().and_then(|i| usize::try_from(i).ok()) + } + /// Returns `true` if the value is a Lua [`Number`]. #[inline] - fn into_iter(self) -> Self::IntoIter { - (&self.0).iter().rev() + pub fn is_number(&self) -> bool { + self.as_number().is_some() + } + + /// Cast the value to [`Number`]. + /// + /// If the value is a Lua [`Number`], returns it or `None` otherwise. + #[inline] + pub fn as_number(&self) -> Option { + match *self { + Value::Number(n) => Some(n), + _ => None, + } + } + + /// Cast the value to `f32`. + /// + /// If the value is a Lua [`Number`], try to convert it to `f32` or return `None` otherwise. + #[inline] + pub fn as_f32(&self) -> Option { + self.as_number().and_then(f32::from_f64) + } + + /// Cast the value to `f64`. + /// + /// If the value is a Lua [`Number`], try to convert it to `f64` or return `None` otherwise. + #[inline] + pub fn as_f64(&self) -> Option { + self.as_number() + } + + /// Returns `true` if the value is a [`LuaString`]. + #[inline] + pub fn is_string(&self) -> bool { + self.as_string().is_some() + } + + /// Cast the value to a [`LuaString`]. + /// + /// If the value is a [`LuaString`], returns it or `None` otherwise. + #[inline] + pub fn as_string(&self) -> Option<&LuaString> { + match self { + Value::String(s) => Some(s), + _ => None, + } + } + + /// Cast the value to [`BorrowedStr`]. + /// + /// If the value is a [`LuaString`], try to convert it to [`BorrowedStr`] or return `None` + /// otherwise. + #[deprecated( + since = "0.11.0", + note = "This method does not follow Rust naming convention. Use `as_string().and_then(|s| s.to_str().ok())` instead." + )] + #[inline] + pub fn as_str(&self) -> Option { + self.as_string().and_then(|s| s.to_str().ok()) + } + + /// Cast the value to [`String`]. + /// + /// If the value is a [`LuaString`], converts it to [`String`] or returns `None` otherwise. + #[deprecated( + since = "0.11.0", + note = "This method does not follow Rust naming convention. Use `as_string().map(|s| s.to_string_lossy())` instead." + )] + #[inline] + pub fn as_string_lossy(&self) -> Option { + self.as_string().map(|s| s.to_string_lossy()) } -} -impl<'lua> MultiValue<'lua> { + /// Returns `true` if the value is a Lua [`Table`]. #[inline] - pub fn from_vec(mut v: Vec>) -> MultiValue<'lua> { - v.reverse(); - MultiValue(v) + pub fn is_table(&self) -> bool { + self.as_table().is_some() } + /// Cast the value to [`Table`]. + /// + /// If the value is a Lua [`Table`], returns it or `None` otherwise. #[inline] - pub fn into_vec(self) -> Vec> { - let mut v = self.0; - v.reverse(); - v + pub fn as_table(&self) -> Option<&Table> { + match self { + Value::Table(t) => Some(t), + _ => None, + } } + /// Returns `true` if the value is a Lua [`Thread`]. #[inline] - pub(crate) fn reserve(&mut self, size: usize) { - self.0.reserve(size); + pub fn is_thread(&self) -> bool { + self.as_thread().is_some() } + /// Cast the value to [`Thread`]. + /// + /// If the value is a Lua [`Thread`], returns it or `None` otherwise. #[inline] - pub(crate) fn push_front(&mut self, value: Value<'lua>) { - self.0.push(value); + pub fn as_thread(&self) -> Option<&Thread> { + match self { + Value::Thread(t) => Some(t), + _ => None, + } } + /// Returns `true` if the value is a Lua [`Function`]. #[inline] - pub(crate) fn pop_front(&mut self) -> Option> { - self.0.pop() + pub fn is_function(&self) -> bool { + self.as_function().is_some() } + /// Cast the value to [`Function`]. + /// + /// If the value is a Lua [`Function`], returns it or `None` otherwise. #[inline] - pub fn clear(&mut self) { - self.0.clear(); + pub fn as_function(&self) -> Option<&Function> { + match self { + Value::Function(f) => Some(f), + _ => None, + } } + /// Returns `true` if the value is an [`AnyUserData`]. #[inline] - pub fn len(&self) -> usize { - self.0.len() + pub fn is_userdata(&self) -> bool { + self.as_userdata().is_some() } + /// Cast the value to [`AnyUserData`]. + /// + /// If the value is an [`AnyUserData`], returns it or `None` otherwise. #[inline] - pub fn is_empty(&self) -> bool { - self.0.is_empty() + pub fn as_userdata(&self) -> Option<&AnyUserData> { + match self { + Value::UserData(ud) => Some(ud), + _ => None, + } } + /// Cast the value to a [`Buffer`]. + /// + /// If the value is [`Buffer`], returns it or `None` otherwise. + /// + /// [`Buffer`]: crate::Buffer + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] #[inline] - pub fn iter(&self) -> iter::Rev>> { - self.0.iter().rev() + pub fn as_buffer(&self) -> Option<&crate::Buffer> { + match self { + Value::Buffer(b) => Some(b), + _ => None, + } } + /// Returns `true` if the value is a [`Buffer`]. + /// + /// [`Buffer`]: crate::Buffer + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] #[inline] - pub(crate) fn drain_all(&mut self) -> iter::Rev>> { - self.0.drain(..).rev() + pub fn is_buffer(&self) -> bool { + self.as_buffer().is_some() } + /// Returns `true` if the value is an [`Error`]. #[inline] - pub(crate) fn refill( - &mut self, - iter: impl IntoIterator>>, - ) -> Result<()> { - self.0.clear(); - for value in iter { - self.0.push(value?); + pub fn is_error(&self) -> bool { + self.as_error().is_some() + } + + /// Cast the value to [`Error`]. + /// + /// If the value is an [`Error`], returns it or `None` otherwise. + pub fn as_error(&self) -> Option<&Error> { + match self { + Value::Error(e) => Some(e), + _ => None, + } + } + + /// Wrap reference to this Value into [`SerializableValue`]. + /// + /// This allows customizing serialization behavior using serde. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + pub fn to_serializable(&self) -> SerializableValue<'_> { + SerializableValue::new(self, Default::default(), None) + } + + // Compares two values. + // Used to sort values for Debug printing. + pub(crate) fn sort_cmp(&self, other: &Self) -> Ordering { + fn cmp_num(a: Number, b: Number) -> Ordering { + match (a, b) { + _ if a < b => Ordering::Less, + _ if a > b => Ordering::Greater, + _ => Ordering::Equal, + } + } + + match (self, other) { + // Nil + (Value::Nil, Value::Nil) => Ordering::Equal, + (Value::Nil, _) => Ordering::Less, + (_, Value::Nil) => Ordering::Greater, + // Null (a special case) + (Value::LightUserData(ud1), Value::LightUserData(ud2)) if ud1 == ud2 => Ordering::Equal, + (Value::LightUserData(ud1), _) if ud1.0.is_null() => Ordering::Less, + (_, Value::LightUserData(ud2)) if ud2.0.is_null() => Ordering::Greater, + // Boolean + (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b), + (Value::Boolean(_), _) => Ordering::Less, + (_, Value::Boolean(_)) => Ordering::Greater, + // Integer && Number + (Value::Integer(a), Value::Integer(b)) => a.cmp(b), + (Value::Integer(a), Value::Number(b)) => cmp_num(*a as Number, *b), + (Value::Number(a), Value::Integer(b)) => cmp_num(*a, *b as Number), + (Value::Number(a), Value::Number(b)) => cmp_num(*a, *b), + (Value::Integer(_) | Value::Number(_), _) => Ordering::Less, + (_, Value::Integer(_) | Value::Number(_)) => Ordering::Greater, + // Vector (Luau) + #[cfg(feature = "luau")] + (Value::Vector(a), Value::Vector(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal), + // String + (Value::String(a), Value::String(b)) => a.as_bytes().cmp(&b.as_bytes()), + (Value::String(_), _) => Ordering::Less, + (_, Value::String(_)) => Ordering::Greater, + // Other variants can be ordered by their pointer + (a, b) => a.to_pointer().cmp(&b.to_pointer()), + } + } + + pub(crate) fn fmt_pretty( + &self, + fmt: &mut fmt::Formatter, + recursive: bool, + ident: usize, + visited: &mut HashSet<*const c_void>, + ) -> fmt::Result { + match self { + Value::Nil => write!(fmt, "nil"), + Value::Boolean(b) => write!(fmt, "{b}"), + Value::LightUserData(ud) if ud.0.is_null() => write!(fmt, "null"), + Value::LightUserData(ud) => write!(fmt, "lightuserdata: {:?}", ud.0), + Value::Integer(i) => write!(fmt, "{i}"), + Value::Number(n) => write!(fmt, "{n}"), + #[cfg(feature = "luau")] + Value::Vector(v) => write!(fmt, "{v}"), + Value::String(s) => write!(fmt, "{s:?}"), + Value::Table(t) if recursive && !visited.contains(&t.to_pointer()) => { + t.fmt_pretty(fmt, ident, visited) + } + t @ Value::Table(_) => write!(fmt, "table: {:?}", t.to_pointer()), + f @ Value::Function(_) => write!(fmt, "function: {:?}", f.to_pointer()), + t @ Value::Thread(_) => write!(fmt, "thread: {:?}", t.to_pointer()), + Value::UserData(ud) => ud.fmt_pretty(fmt), + #[cfg(feature = "luau")] + buf @ Value::Buffer(_) => write!(fmt, "buffer: {:?}", buf.to_pointer()), + Value::Error(e) if recursive => write!(fmt, "{e:?}"), + Value::Error(_) => write!(fmt, "error"), + Value::Other(v) => write!(fmt, "other: {:?}", v.to_pointer()), } - self.0.reverse(); - Ok(()) } } -/// Trait for types convertible to any number of Lua values. -/// -/// This is a generalization of `ToLua`, allowing any number of resulting Lua values instead of just -/// one. Any type that implements `ToLua` will automatically implement this trait. -pub trait ToLuaMulti<'lua> { - /// Performs the conversion. - fn to_lua_multi(self, lua: &'lua Lua) -> Result>; +impl fmt::Debug for Value { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + if fmt.alternate() { + return self.fmt_pretty(fmt, true, 0, &mut HashSet::new()); + } + + match self { + Value::Nil => write!(fmt, "Nil"), + Value::Boolean(b) => write!(fmt, "Boolean({b})"), + Value::LightUserData(ud) => write!(fmt, "{ud:?}"), + Value::Integer(i) => write!(fmt, "Integer({i})"), + Value::Number(n) => write!(fmt, "Number({n})"), + #[cfg(feature = "luau")] + Value::Vector(v) => write!(fmt, "{v:?}"), + Value::String(s) => write!(fmt, "String({s:?})"), + Value::Table(t) => write!(fmt, "{t:?}"), + Value::Function(f) => write!(fmt, "{f:?}"), + Value::Thread(t) => write!(fmt, "{t:?}"), + Value::UserData(ud) => write!(fmt, "{ud:?}"), + #[cfg(feature = "luau")] + Value::Buffer(buf) => write!(fmt, "{buf:?}"), + Value::Error(e) => write!(fmt, "Error({e:?})"), + Value::Other(v) => write!(fmt, "Other({v:?})"), + } + } } -/// Trait for types that can be created from an arbitrary number of Lua values. -/// -/// This is a generalization of `FromLua`, allowing an arbitrary number of Lua values to participate -/// in the conversion. Any type that implements `FromLua` will automatically implement this trait. -pub trait FromLuaMulti<'lua>: Sized { - /// Performs the conversion. - /// - /// In case `values` contains more values than needed to perform the conversion, the excess - /// values should be ignored. This reflects the semantics of Lua when calling a function or - /// assigning values. Similarly, if not enough values are given, conversions should assume that - /// any missing values are nil. - fn from_lua_multi(values: MultiValue<'lua>, lua: &'lua Lua) -> Result; +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Value::Nil, Value::Nil) => true, + (Value::Boolean(a), Value::Boolean(b)) => a == b, + (Value::LightUserData(a), Value::LightUserData(b)) => a == b, + (Value::Integer(a), Value::Integer(b)) => *a == *b, + (Value::Integer(a), Value::Number(b)) => *a as Number == *b, + (Value::Number(a), Value::Integer(b)) => *a == *b as Number, + (Value::Number(a), Value::Number(b)) => *a == *b, + #[cfg(feature = "luau")] + (Value::Vector(v1), Value::Vector(v2)) => v1 == v2, + (Value::String(a), Value::String(b)) => a == b, + (Value::Table(a), Value::Table(b)) => a == b, + (Value::Function(a), Value::Function(b)) => a == b, + (Value::Thread(a), Value::Thread(b)) => a == b, + (Value::UserData(a), Value::UserData(b)) => a == b, + #[cfg(feature = "luau")] + (Value::Buffer(a), Value::Buffer(b)) => a == b, + _ => false, + } + } +} + +/// A wrapped [`Value`] with customized serialization behavior. +#[cfg(feature = "serde")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] +pub struct SerializableValue<'a> { + value: &'a Value, + options: crate::serde::de::Options, + // In many cases we don't need `visited` map, so don't allocate memory by default + visited: Option>>>, +} + +#[cfg(feature = "serde")] +impl Serialize for Value { + #[inline] + fn serialize(&self, serializer: S) -> StdResult { + SerializableValue::new(self, Default::default(), None).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> SerializableValue<'a> { + #[inline] + pub(crate) fn new( + value: &'a Value, + options: crate::serde::de::Options, + visited: Option<&Rc>>>, + ) -> Self { + if let Value::Table(_) = value { + return Self { + value, + options, + // We need to always initialize the `visited` map for Tables + visited: visited.cloned().or_else(|| Some(Default::default())), + }; + } + Self { + value, + options, + visited: None, + } + } + + /// If true, an attempt to serialize types such as [`Function`], [`Thread`], [`LightUserData`] + /// and [`Error`] will cause an error. + /// Otherwise these types skipped when iterating or serialized as unit type. + /// + /// Default: **true** + #[must_use] + pub fn deny_unsupported_types(mut self, enabled: bool) -> Self { + self.options.deny_unsupported_types = enabled; + self + } + + /// If true, an attempt to serialize a recursive table (table that refers to itself) + /// will cause an error. + /// Otherwise subsequent attempts to serialize the same table will be ignored. + /// + /// Default: **true** + #[must_use] + pub fn deny_recursive_tables(mut self, enabled: bool) -> Self { + self.options.deny_recursive_tables = enabled; + self + } + + /// If true, keys in tables will be iterated (and serialized) in sorted order. + /// + /// Default: **false** + #[must_use] + pub fn sort_keys(mut self, enabled: bool) -> Self { + self.options.sort_keys = enabled; + self + } + + /// If true, empty Lua tables will be encoded as array, instead of map. + /// + /// Default: **false** + #[must_use] + pub fn encode_empty_tables_as_array(mut self, enabled: bool) -> Self { + self.options.encode_empty_tables_as_array = enabled; + self + } + + /// If true, enable detection of mixed tables. + /// + /// A mixed table is a table that has both array-like and map-like entries or several borders. + /// + /// Default: **false** + #[must_use] + pub fn detect_mixed_tables(mut self, enabled: bool) -> Self { + self.options.detect_mixed_tables = enabled; + self + } +} + +#[cfg(feature = "serde")] +impl Serialize for SerializableValue<'_> { + fn serialize(&self, serializer: S) -> StdResult + where + S: Serializer, + { + match self.value { + Value::Nil => serializer.serialize_unit(), + Value::Boolean(b) => serializer.serialize_bool(*b), + #[allow(clippy::useless_conversion)] + Value::Integer(i) => serializer.serialize_i64((*i).into()), + Value::Number(n) => serializer.serialize_f64(*n), + #[cfg(feature = "luau")] + Value::Vector(v) => v.serialize(serializer), + Value::String(s) => s.serialize(serializer), + Value::Table(t) => { + let visited = self.visited.as_ref().unwrap().clone(); + SerializableTable::new(t, self.options, visited).serialize(serializer) + } + Value::LightUserData(ud) if ud.0.is_null() => serializer.serialize_none(), + Value::UserData(ud) if ud.is_serializable() || self.options.deny_unsupported_types => { + ud.serialize(serializer) + } + #[cfg(feature = "luau")] + Value::Buffer(buf) => buf.serialize(serializer), + Value::Function(_) + | Value::Thread(_) + | Value::UserData(_) + | Value::LightUserData(_) + | Value::Error(_) + | Value::Other(_) => { + if self.options.deny_unsupported_types { + let msg = format!("cannot serialize <{}>", self.value.type_name()); + Err(ser::Error::custom(msg)) + } else { + serializer.serialize_unit() + } + } + } + } +} + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_any!(Value: Send); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(Value: Send, Sync); } diff --git a/src/vector.rs b/src/vector.rs new file mode 100644 index 00000000..c292e5d9 --- /dev/null +++ b/src/vector.rs @@ -0,0 +1,91 @@ +use std::fmt; + +#[cfg(feature = "serde")] +use serde::ser::{Serialize, SerializeTupleStruct, Serializer}; + +/// A Luau vector type. +/// +/// By default vectors are 3-dimensional, but can be 4-dimensional +/// if the `luau-vector4` feature is enabled. +#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] +#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd)] +pub struct Vector(pub(crate) [f32; Self::SIZE]); + +impl fmt::Display for Vector { + #[rustfmt::skip] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[cfg(not(feature = "luau-vector4"))] + return write!(f, "vector({}, {}, {})", self.x(), self.y(), self.z()); + #[cfg(feature = "luau-vector4")] + return write!(f, "vector({}, {}, {}, {})", self.x(), self.y(), self.z(), self.w()); + } +} + +#[cfg_attr(not(feature = "luau"), allow(unused))] +impl Vector { + pub(crate) const SIZE: usize = if cfg!(feature = "luau-vector4") { 4 } else { 3 }; + + /// Creates a new vector. + #[cfg(not(feature = "luau-vector4"))] + pub const fn new(x: f32, y: f32, z: f32) -> Self { + Self([x, y, z]) + } + + /// Creates a new vector. + #[cfg(feature = "luau-vector4")] + pub const fn new(x: f32, y: f32, z: f32, w: f32) -> Self { + Self([x, y, z, w]) + } + + /// Creates a new vector with all components set to `0.0`. + pub const fn zero() -> Self { + Self([0.0; Self::SIZE]) + } + + /// Returns 1st component of the vector. + pub const fn x(&self) -> f32 { + self.0[0] + } + + /// Returns 2nd component of the vector. + pub const fn y(&self) -> f32 { + self.0[1] + } + + /// Returns 3rd component of the vector. + pub const fn z(&self) -> f32 { + self.0[2] + } + + /// Returns 4th component of the vector. + #[cfg(any(feature = "luau-vector4", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau-vector4")))] + pub const fn w(&self) -> f32 { + self.0[3] + } +} + +#[cfg(feature = "serde")] +impl Serialize for Vector { + fn serialize(&self, serializer: S) -> std::result::Result { + let mut ts = serializer.serialize_tuple_struct("Vector", Self::SIZE)?; + ts.serialize_field(&self.x())?; + ts.serialize_field(&self.y())?; + ts.serialize_field(&self.z())?; + #[cfg(feature = "luau-vector4")] + ts.serialize_field(&self.w())?; + ts.end() + } +} + +impl PartialEq<[f32; Self::SIZE]> for Vector { + #[inline] + fn eq(&self, other: &[f32; Self::SIZE]) -> bool { + self.0 == *other + } +} + +#[cfg(feature = "luau")] +impl crate::types::LuaType for Vector { + const TYPE_ID: std::os::raw::c_int = ffi::LUA_TVECTOR; +} diff --git a/tarpaulin.toml b/tarpaulin.toml new file mode 100644 index 00000000..110456cf --- /dev/null +++ b/tarpaulin.toml @@ -0,0 +1,23 @@ +[lua55] +features = "lua55,vendored,async,send,serde,macros,anyhow,userdata-wrappers" + +[lua55_non_send] +features = "lua55,vendored,async,serde,macros,anyhow,userdata-wrappers" + +[lua55_with_memory_limit] +features = "lua55,vendored,async,send,serde,macros,anyhow,userdata-wrappers" +rustflags = "--cfg force_memory_limit" + +[lua51] +features = "lua51,vendored,async,send,serde,macros" + +[lua51_with_memory_limit] +features = "lua51,vendored,async,send,serde,macros" +rustflags = "--cfg force_memory_limit" + +[luau] +features = "luau,async,send,serde,macros" + +[luau_with_memory_limit] +features = "luau,async,send,serde,macros" +rustflags = "--cfg force_memory_limit" diff --git a/tests/async.rs b/tests/async.rs index fcfb9c7c..55e41593 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,27 +1,32 @@ #![cfg(feature = "async")] -use std::cell::Cell; -use std::rc::Rc; -use std::sync::{ - atomic::{AtomicI64, AtomicU64, Ordering}, - Arc, -}; +use std::sync::Arc; use std::time::Duration; -use futures_timer::Delay; use futures_util::stream::TryStreamExt; +use tokio::sync::Mutex; use mlua::{ - Error, Function, Lua, LuaOptions, Result, StdLib, Table, TableExt, Thread, UserData, - UserDataMethods, Value, + Error, Function, Lua, LuaOptions, MultiValue, ObjectLike, Result, StdLib, Table, UserData, + UserDataMethods, UserDataRef, Value, }; +#[cfg(not(target_arch = "wasm32"))] +async fn sleep_ms(ms: u64) { + tokio::time::sleep(Duration::from_millis(ms)).await; +} + +#[cfg(target_arch = "wasm32")] +async fn sleep_ms(_ms: u64) { + // I was unable to make sleep() work in wasm32-emscripten target + tokio::task::yield_now().await; +} + #[tokio::test] async fn test_async_function() -> Result<()> { let lua = Lua::new(); - let f = lua - .create_async_function(|_lua, (a, b, c): (i64, i64, i64)| async move { Ok((a + b) * c) })?; + let f = lua.create_async_function(|_lua, (a, b, c): (i64, i64, i64)| async move { Ok((a + b) * c) })?; lua.globals().set("f", f)?; let res: i64 = lua.load("f(1, 2, 3)").eval_async().await?; @@ -30,12 +35,64 @@ async fn test_async_function() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_async_function_wrap() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_async(|s: String| async move { + tokio::task::yield_now().await; + Ok::<_, Error>(s) + }); + lua.globals().set("f", f)?; + let res: String = lua.load(r#"f("hello")"#).eval_async().await?; + assert_eq!(res, "hello"); + + // Return error + let ferr = Function::wrap_async(|| async move { Err::<(), _>(Error::runtime("some async error")) }); + lua.globals().set("ferr", ferr)?; + lua.load( + r#" + local ok, err = pcall(ferr) + assert(not ok and tostring(err):find("some async error")) + "#, + ) + .exec_async() + .await + .unwrap(); + + Ok(()) +} + +#[tokio::test] +async fn test_async_function_wrap_raw() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_raw_async(|s: String| async move { + tokio::task::yield_now().await; + s + }); + lua.globals().set("f", f)?; + let res: String = lua.load(r#"f("hello")"#).eval_async().await?; + assert_eq!(res, "hello"); + + // Return error + let ferr = Function::wrap_raw_async(|| async move { + tokio::task::yield_now().await; + Err::<(), _>("some error") + }); + lua.globals().set("ferr", ferr)?; + let (_, err): (Value, String) = lua.load(r#"ferr()"#).eval_async().await?; + assert_eq!(err, "some error"); + + Ok(()) +} + #[tokio::test] async fn test_async_sleep() -> Result<()> { let lua = Lua::new(); let sleep = lua.create_async_function(move |_lua, n: u64| async move { - Delay::new(Duration::from_millis(n)).await; + sleep_ms(n).await; Ok(format!("elapsed:{}ms", n)) })?; lua.globals().set("sleep", sleep)?; @@ -51,22 +108,39 @@ async fn test_async_call() -> Result<()> { let lua = Lua::new(); let hello = lua.create_async_function(|_lua, name: String| async move { - Delay::new(Duration::from_millis(10)).await; + sleep_ms(10).await; Ok(format!("hello, {}!", name)) })?; - match hello.call::<_, ()>("alex") { + match hello.call::<()>("alex") { Err(Error::RuntimeError(_)) => {} - _ => panic!( - "non-async executing async function must fail on the yield stage with RuntimeError" - ), + err => panic!("expected `RuntimeError`, got {err:?}"), }; - assert_eq!(hello.call_async::<_, String>("alex").await?, "hello, alex!"); + assert_eq!(hello.call_async::("alex").await?, "hello, alex!"); // Executing non-async functions using async call is allowed let sum = lua.create_function(|_lua, (a, b): (i64, i64)| return Ok(a + b))?; - assert_eq!(sum.call_async::<_, i64>((5, 1)).await?, 6); + assert_eq!(sum.call_async::((5, 1)).await?, 6); + + Ok(()) +} + +#[tokio::test] +async fn test_async_call_many_returns() -> Result<()> { + let lua = Lua::new(); + + let hello = lua.create_async_function(|_lua, ()| async move { + sleep_ms(10).await; + Ok(("a", "b", "c", 1)) + })?; + + let vals = hello.call_async::(()).await?; + assert_eq!(vals.len(), 4); + assert_eq!(vals[0].to_string()?, "a"); + assert_eq!(vals[1].to_string()?, "b"); + assert_eq!(vals[2].to_string()?, "c"); + assert_eq!(vals[3], Value::Integer(1)); Ok(()) } @@ -94,7 +168,7 @@ async fn test_async_handle_yield() -> Result<()> { let lua = Lua::new(); let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move { - Delay::new(Duration::from_millis(10)).await; + sleep_ms(10).await; Ok(a + b) })?; @@ -124,7 +198,7 @@ async fn test_async_handle_yield() -> Result<()> { "#, ) .eval::()?; - assert_eq!(min.call_async::<_, i64>((-1, 1)).await?, -1); + assert_eq!(min.call_async::((-1, 1)).await?, -1); Ok(()) } @@ -152,10 +226,10 @@ async fn test_async_return_async_closure() -> Result<()> { let lua = Lua::new(); let f = lua.create_async_function(|lua, a: i64| async move { - Delay::new(Duration::from_millis(10)).await; + sleep_ms(10).await; let g = lua.create_async_function(move |_, b: i64| async move { - Delay::new(Duration::from_millis(10)).await; + sleep_ms(10).await; return Ok(a + b); })?; @@ -174,6 +248,38 @@ async fn test_async_return_async_closure() -> Result<()> { Ok(()) } +#[cfg(any(feature = "lua55", feature = "lua54"))] +#[tokio::test] +async fn test_async_lua54_to_be_closed() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + globals.set("close_count", 0)?; + + let code = r#" + local t = setmetatable({}, { + __close = function() + close_count = close_count + 1 + end + }) + error "test" + "#; + let f = lua.load(code).into_function()?; + + // Test close using call_async + let _ = f.call_async::<()>(()).await; + assert_eq!(globals.get::("close_count")?, 1); + + // Don't close by default when awaiting async threads + let co = lua.create_thread(f.clone())?; + let _ = co.clone().into_async::<()>(())?.await; + assert_eq!(globals.get::("close_count")?, 1); + let _ = co.reset(f); + assert_eq!(globals.get::("close_count")?, 2); + + Ok(()) +} + #[tokio::test] async fn test_async_thread_stream() -> Result<()> { let lua = Lua::new(); @@ -193,7 +299,7 @@ async fn test_async_thread_stream() -> Result<()> { .eval()?, )?; - let mut stream = thread.into_async::<_, i64>(1); + let mut stream = thread.into_async::(1)?; let mut sum = 0; while let Some(n) = stream.try_next().await? { sum += n; @@ -213,12 +319,12 @@ async fn test_async_thread() -> Result<()> { let f = lua.create_async_function(move |_lua, ()| { let cnt3 = cnt2.clone(); async move { - Delay::new(Duration::from_millis(*cnt3.as_ref())).await; + sleep_ms(*cnt3.as_ref()).await; Ok("done") } })?; - let res: String = lua.create_thread(f)?.into_async(()).await?; + let res: String = lua.create_thread(f)?.into_async(())?.await?; assert_eq!(res, "done"); @@ -229,108 +335,146 @@ async fn test_async_thread() -> Result<()> { Ok(()) } +#[test] +fn test_async_thread_capture() -> Result<()> { + let lua = Lua::new(); + + let f = lua.create_async_function(move |_lua, v: Value| async move { + tokio::task::yield_now().await; + drop(v); + Ok(()) + })?; + + let thread = lua.create_thread(f)?; + // After first resume, `v: Value` is captured in the coroutine + thread.resume::<()>("abc").unwrap(); + drop(thread); + + Ok(()) +} + #[tokio::test] -async fn test_async_table() -> Result<()> { - let options = LuaOptions::new().thread_cache_size(4); +async fn test_async_table_object_like() -> Result<()> { + let options = LuaOptions::new().thread_pool_size(4); let lua = Lua::new_with(StdLib::ALL_SAFE, options)?; let table = lua.create_table()?; table.set("val", 10)?; let get_value = lua.create_async_function(|_, table: Table| async move { - Delay::new(Duration::from_millis(10)).await; - table.get::<_, i64>("val") + sleep_ms(10).await; + table.get::("val") })?; table.set("get_value", get_value)?; let set_value = lua.create_async_function(|_, (table, n): (Table, i64)| async move { - Delay::new(Duration::from_millis(10)).await; + sleep_ms(10).await; table.set("val", n) })?; table.set("set_value", set_value)?; + assert_eq!(table.call_async_method::("get_value", ()).await?, 10); + table.call_async_method::<()>("set_value", 15).await?; + assert_eq!(table.call_async_method::("get_value", ()).await?, 15); + + let metatable = lua.create_table()?; + metatable.set( + "__call", + lua.create_async_function(|_, table: Table| async move { + sleep_ms(10).await; + table.get::("val") + })?, + )?; + table.set_metatable(Some(metatable))?; + assert_eq!(table.call_async::(()).await.unwrap(), 15); + + match table.call_async_method::<()>("non_existent", ()).await { + Err(Error::RuntimeError(err)) => { + assert!(err.contains("attempt to call a nil value (function 'non_existent')")) + } + r => panic!("expected RuntimeError, got {r:?}"), + } + + Ok(()) +} + +#[tokio::test] +async fn test_async_thread_pool() -> Result<()> { + let options = LuaOptions::new().thread_pool_size(4); + let lua = Lua::new_with(StdLib::ALL_SAFE, options)?; + + let error_f = lua.create_async_function(|_, ()| async move { + sleep_ms(10).await; + Err::<(), _>(Error::runtime("test")) + })?; + let sleep = lua.create_async_function(|_, n| async move { - Delay::new(Duration::from_millis(n)).await; + sleep_ms(n).await; Ok(format!("elapsed:{}ms", n)) })?; - table.set("sleep", sleep)?; - assert_eq!( - table - .call_async_method::<_, _, i64>("get_value", ()) - .await?, - 10 - ); - table.call_async_method("set_value", 15).await?; - assert_eq!( - table - .call_async_method::<_, _, i64>("get_value", ()) - .await?, - 15 - ); - assert_eq!( - table - .call_async_function::<_, _, String>("sleep", 7) - .await?, - "elapsed:7ms" - ); + assert!(error_f.call_async::<()>(()).await.is_err()); + // Next call should use cached thread + assert_eq!(sleep.call_async::(3).await?, "elapsed:3ms"); Ok(()) } #[tokio::test] async fn test_async_userdata() -> Result<()> { - #[derive(Clone)] - struct MyUserData(Arc); + struct MyUserdata(u64); - impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + impl UserData for MyUserdata { + fn add_methods>(methods: &mut M) { methods.add_async_method("get_value", |_, data, ()| async move { - Delay::new(Duration::from_millis(10)).await; - Ok(data.0.load(Ordering::Relaxed)) + sleep_ms(10).await; + Ok(data.0) }); - methods.add_async_method("set_value", |_, data, n| async move { - Delay::new(Duration::from_millis(10)).await; - data.0.store(n, Ordering::Relaxed); + methods.add_async_method_mut("set_value", |_, mut data, n| async move { + sleep_ms(10).await; + data.0 = n; Ok(()) }); + methods.add_async_method_once("take_value", |_, data, ()| async move { + sleep_ms(10).await; + Ok(data.0) + }); + methods.add_async_function("sleep", |_, n| async move { - Delay::new(Duration::from_millis(n)).await; + sleep_ms(n).await; Ok(format!("elapsed:{}ms", n)) }); #[cfg(not(any(feature = "lua51", feature = "luau")))] methods.add_async_meta_method(mlua::MetaMethod::Call, |_, data, ()| async move { - let n = data.0.load(Ordering::Relaxed); - Delay::new(Duration::from_millis(n)).await; + let n = data.0; + sleep_ms(n).await; Ok(format!("elapsed:{}ms", n)) }); #[cfg(not(any(feature = "lua51", feature = "luau")))] - methods.add_async_meta_method( - mlua::MetaMethod::Index, - |_, data, key: String| async move { - Delay::new(Duration::from_millis(10)).await; - match key.as_str() { - "ms" => Ok(Some(data.0.load(Ordering::Relaxed) as f64)), - "s" => Ok(Some((data.0.load(Ordering::Relaxed) as f64) / 1000.0)), - _ => Ok(None), - } - }, - ); + methods.add_async_meta_method(mlua::MetaMethod::Index, |_, data, key: String| async move { + sleep_ms(10).await; + match key.as_str() { + "ms" => Ok(Some(data.0 as f64)), + "s" => Ok(Some((data.0 as f64) / 1000.0)), + _ => Ok(None), + } + }); #[cfg(not(any(feature = "lua51", feature = "luau")))] - methods.add_async_meta_method( + methods.add_async_meta_method_mut( mlua::MetaMethod::NewIndex, - |_, data, (key, value): (String, f64)| async move { - Delay::new(Duration::from_millis(10)).await; + |_, mut data, (key, value): (String, f64)| async move { + sleep_ms(10).await; match key.as_str() { - "ms" => Ok(data.0.store(value as u64, Ordering::Relaxed)), - "s" => Ok(data.0.store((value * 1000.0) as u64, Ordering::Relaxed)), - _ => Err(Error::external(format!("key '{}' not found", key))), + "ms" => data.0 = value as u64, + "s" => data.0 = (value * 1000.0) as u64, + _ => return Err(Error::external(format!("key '{}' not found", key))), } + Ok(()) }, ); } @@ -339,8 +483,8 @@ async fn test_async_userdata() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); - let userdata = lua.create_userdata(MyUserData(Arc::new(AtomicU64::new(11))))?; - globals.set("userdata", userdata.clone())?; + let userdata = lua.create_userdata(MyUserdata(11))?; + globals.set("userdata", &userdata)?; lua.load( r#" @@ -369,123 +513,208 @@ async fn test_async_userdata() -> Result<()> { .exec_async() .await?; + // ObjectLike methods + userdata.call_async_method::<()>("set_value", 24).await?; + let n: u64 = userdata.call_async_method("get_value", ()).await?; + assert_eq!(n, 24); + userdata.call_async_function::<()>("sleep", 15).await?; + + #[cfg(not(any(feature = "lua51", feature = "luau")))] + assert_eq!(userdata.call_async::(()).await?, "elapsed:24ms"); + + // Take value + let userdata2 = lua.create_userdata(MyUserdata(0))?; + globals.set("userdata2", userdata2)?; + lua.load("assert(userdata:take_value() == 24)") + .exec_async() + .await?; + match lua.load("userdata2.take_value(userdata)").exec_async().await { + Err(Error::CallbackError { cause, .. }) => { + let err = cause.to_string(); + assert!(err.contains("bad argument `self` to `MyUserdata.take_value`")); + assert!(err.contains("userdata has been destructed")); + } + r => panic!("expected Err(CallbackError), got {r:?}"), + } + Ok(()) } #[tokio::test] -async fn test_async_scope() -> Result<()> { - let ref lua = Lua::new(); +async fn test_async_thread_error() -> Result<()> { + struct MyUserData; + + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { + methods.add_meta_method("__tostring", |_, _this, ()| Ok("myuserdata error")) + } + } - let ref rc = Rc::new(Cell::new(0)); + let lua = Lua::new(); + let result = lua + .load("function x(...) error(...) end x(...)") + .set_name("chunk") + .call_async::<()>(MyUserData) + .await; + assert!( + matches!(result, Err(Error::RuntimeError(cause)) if cause.contains("myuserdata error")), + "improper error traceback from dead thread" + ); - let fut = lua.async_scope(|scope| async move { - let f = scope.create_async_function(move |_, n: u64| { - let rc2 = rc.clone(); + Ok(()) +} + +#[tokio::test] +async fn test_async_terminate() -> Result<()> { + // Future captures `Lua` instance and dropped all together + let mutex = Arc::new(Mutex::new(0u32)); + { + let lua = Lua::new(); + let mutex2 = mutex.clone(); + let func = lua.create_async_function(move |lua, ()| { + let mutex = mutex2.clone(); async move { - rc2.set(42); - Delay::new(Duration::from_millis(n)).await; - assert_eq!(Rc::strong_count(&rc2), 2); + let _guard = mutex.lock().await; + sleep_ms(100).await; + drop(lua); // Move Lua to the future to test drop Ok(()) } })?; - lua.globals().set("f", f.clone())?; - - assert_eq!(Rc::strong_count(rc), 1); - let _ = f.call_async::(10).await?; - assert_eq!(Rc::strong_count(rc), 1); - - // Create future in partialy polled state (Poll::Pending) - let g = lua.create_thread(f)?; - g.resume::(10)?; - lua.globals().set("g", g)?; - assert_eq!(Rc::strong_count(rc), 2); + let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(())).await; + } + assert!(mutex.try_lock().is_ok()); + // Future is dropped, but `Lua` instance is still alive + let lua = Lua::new(); + let func = lua.create_async_function(move |_, mutex: UserDataRef>>| async move { + let _guard = mutex.lock().await; + sleep_ms(100).await; Ok(()) - }); + })?; + let mutex2 = lua.create_any_userdata(mutex.clone())?; + let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(mutex2)).await; + assert!(mutex.try_lock().is_ok()); - assert_eq!(Rc::strong_count(rc), 1); - let _ = fut.await?; - assert_eq!(Rc::strong_count(rc), 1); + Ok(()) +} - match lua - .globals() - .get::<_, Function>("f")? - .call_async::<_, ()>(10) - .await - { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - e => panic!("expected `CallbackDestructed` error cause, got {:?}", e), - }, - r => panic!("improper return for destructed function: {:?}", r), - }; +#[tokio::test] +async fn test_async_task() -> Result<()> { + let lua = Lua::new(); - match lua.globals().get::<_, Thread>("g")?.resume::<_, Value>(()) { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - e => panic!("expected `CallbackDestructed` error cause, got {:?}", e), - }, - r => panic!("improper return for destructed function: {:?}", r), - }; + let delay = lua.create_function(|lua, (secs, f, args): (f32, Function, MultiValue)| { + let thread = lua.create_thread(f)?; + let thread2 = thread.clone().into_async::<()>(args)?; + tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_secs_f32(secs)).await; + _ = thread2.await; + }); + Ok(thread) + })?; + + lua.globals().set("delay", delay)?; + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + _ = lua + .load("delay(0.1, function(msg) global_msg = msg end, 'done')") + .exec_async() + .await; + }) + .await; + local.await; + assert_eq!(lua.globals().get::("global_msg")?, "done"); Ok(()) } #[tokio::test] -async fn test_async_scope_userdata() -> Result<()> { - #[derive(Clone)] - struct MyUserData(Arc); +async fn test_async_task_abort() -> Result<()> { + let lua = Lua::new(); - impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_method("get_value", |_, data, ()| async move { - Delay::new(Duration::from_millis(10)).await; - Ok(data.0.load(Ordering::Relaxed)) - }); + let sleep = lua.create_async_function(move |_lua, n: u64| async move { + sleep_ms(n).await; + Ok(()) + })?; + lua.globals().set("sleep", sleep)?; - methods.add_async_method("set_value", |_, data, n| async move { - Delay::new(Duration::from_millis(10)).await; - data.0.store(n, Ordering::Relaxed); - Ok(()) + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let lua2 = lua.clone(); + let jh = tokio::task::spawn_local(async move { + lua2.load("sleep(200) result = 'done'") + .exec_async() + .await + .unwrap(); }); + sleep_ms(100).await; // Wait for the task to start + jh.abort(); + }) + .await; + local.await; + assert_eq!(lua.globals().get::("result")?, Value::Nil); - methods.add_async_function("sleep", |_, n| async move { - Delay::new(Duration::from_millis(n)).await; - Ok(format!("elapsed:{}ms", n)) - }); + Ok(()) +} + +#[tokio::test] +#[cfg(not(feature = "luau"))] +async fn test_async_hook() -> Result<()> { + use std::sync::atomic::{AtomicBool, Ordering}; + + let lua = Lua::new(); + + static HOOK_CALLED: AtomicBool = AtomicBool::new(false); + lua.set_global_hook(mlua::HookTriggers::new().every_line(), move |_, _| { + if !HOOK_CALLED.swap(true, Ordering::Relaxed) { + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + return Ok(mlua::VmState::Yield); } - } + Ok(mlua::VmState::Continue) + })?; - let ref lua = Lua::new(); + let sleep = lua.create_async_function(move |_lua, n: u64| async move { + sleep_ms(n).await; + Ok(()) + })?; + lua.globals().set("sleep", sleep)?; - let ref arc = Arc::new(AtomicI64::new(11)); + lua.load(r"sleep(100)").exec_async().await?; + assert!(HOOK_CALLED.load(Ordering::Relaxed)); - lua.async_scope(|scope| async move { - let ud = scope.create_userdata(MyUserData(arc.clone()))?; - lua.globals().set("userdata", ud)?; - lua.load( - r#" - assert(userdata:get_value() == 11) - userdata:set_value(12) - assert(userdata.sleep(5) == "elapsed:5ms") - assert(userdata:get_value() == 12) - "#, - ) - .exec_async() - .await - }) - .await?; + Ok(()) +} - assert_eq!(Arc::strong_count(arc), 1); +#[test] +fn test_async_yield_with() -> Result<()> { + let lua = Lua::new(); - match lua.load("userdata:get_value()").exec_async().await { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - e => panic!("expected `CallbackDestructed` error cause, got {:?}", e), - }, - r => panic!("improper return for destructed userdata: {:?}", r), - }; + let func = lua.create_async_function(|lua, (mut a, mut b): (i32, i32)| async move { + let zero = lua.yield_with::(()).await?; + assert!(zero.is_empty()); + let one = lua.yield_with::(a + b).await?; + assert_eq!(one.len(), 1); + + for _ in 0..3 { + (a, b) = lua.yield_with((a + b, a * b)).await?; + } + Ok((0, 0)) + })?; + + let thread = lua.create_thread(func)?; + + let zero = thread.resume::((2, 3))?; // function arguments + assert!(zero.is_empty()); + let one = thread.resume::(())?; // value of "zero" is passed here + assert_eq!(one, 5); + + assert_eq!(thread.resume::<(i32, i32)>(1)?, (5, 6)); // value of "one" is passed here + assert_eq!(thread.resume::<(i32, i32)>((10, 11))?, (21, 110)); + assert_eq!(thread.resume::<(i32, i32)>((11, 12))?, (23, 132)); + assert_eq!(thread.resume::<(i32, i32)>((12, 13))?, (0, 0)); + assert!(thread.is_finished()); Ok(()) } diff --git a/tests/buffer.rs b/tests/buffer.rs new file mode 100644 index 00000000..3f07569f --- /dev/null +++ b/tests/buffer.rs @@ -0,0 +1,123 @@ +#![cfg(feature = "luau")] + +use std::io::{Read, Seek, SeekFrom, Write}; + +use mlua::{Lua, Result, Value}; + +#[test] +fn test_buffer() -> Result<()> { + let lua = Lua::new(); + + let buf1 = lua + .load( + r#" + local buf = buffer.fromstring("hello") + assert(buffer.len(buf) == 5) + return buf + "#, + ) + .eval::()?; + assert!(buf1.is_buffer()); + assert_eq!(buf1.type_name(), "buffer"); + + let buf2 = lua.load("buffer.fromstring('hello')").eval::()?; + assert_ne!(buf1, buf2); + + // Check that we can pass buffer type to Lua + let buf1 = buf1.as_buffer().unwrap(); + let func = lua.create_function(|_, buf: Value| return buf.to_string())?; + assert!(func.call::(buf1)?.starts_with("buffer:")); + + // Check buffer methods + assert_eq!(buf1.len(), 5); + assert_eq!(buf1.to_vec(), b"hello"); + assert_eq!(buf1.read_bytes::<3>(1), [b'e', b'l', b'l']); + buf1.write_bytes(1, b"i"); + assert_eq!(buf1.to_vec(), b"hillo"); + + let buf3 = lua.create_buffer(b"")?; + assert!(buf3.is_empty()); + assert!(!Value::Buffer(buf3).to_pointer().is_null()); + + Ok(()) +} + +#[test] +#[should_panic(expected = "out of range for slice of length 13")] +fn test_buffer_out_of_bounds_read() { + let lua = Lua::new(); + let buf = lua.create_buffer(b"hello, world!").unwrap(); + _ = buf.read_bytes::<1>(13); +} + +#[test] +#[should_panic(expected = "out of range for slice of length 13")] +fn test_buffer_out_of_bounds_write() { + let lua = Lua::new(); + let buf = lua.create_buffer(b"hello, world!").unwrap(); + buf.write_bytes(14, b"!!"); +} + +#[test] +fn create_large_buffer() { + let lua = Lua::new(); + let err = lua.create_buffer_with_capacity(1_073_741_824 + 1).unwrap_err(); // 1GB + assert!(err.to_string().contains("memory allocation error")); + + // Normal buffer is okay + let buf = lua.create_buffer_with_capacity(1024 * 1024).unwrap(); + assert_eq!(buf.len(), 1024 * 1024); +} + +#[test] +fn test_buffer_cursor() -> Result<()> { + let lua = Lua::new(); + let mut cursor = lua.create_buffer(b"hello, world")?.cursor(); + + let mut data = Vec::new(); + cursor.read_to_end(&mut data)?; + assert_eq!(data, b"hello, world"); + + // No more data to read + let mut one = [0u8; 1]; + assert_eq!(cursor.read(&mut one)?, 0); + + // Seek to start + cursor.seek(SeekFrom::Start(0))?; + cursor.read_exact(&mut one)?; + assert_eq!(one, [b'h']); + + // Seek to end -5 + cursor.seek(SeekFrom::End(-5))?; + let mut five = [0u8; 5]; + cursor.read_exact(&mut five)?; + assert_eq!(&five, b"world"); + + // Seek to current -1 + cursor.seek(SeekFrom::Current(-1))?; + cursor.read_exact(&mut one)?; + assert_eq!(one, [b'd']); + + // Invalid seek + assert!(cursor.seek(SeekFrom::Current(-100)).is_err()); + assert!(cursor.seek(SeekFrom::End(1)).is_err()); + + // Write data + let buf = lua.create_buffer_with_capacity(100)?; + cursor = buf.clone().cursor(); + + cursor.write_all(b"hello, ...")?; + cursor.seek(SeekFrom::Current(-3))?; + cursor.write_all(b"Rust!")?; + + assert_eq!(&buf.read_bytes::<12>(0), b"hello, Rust!"); + + // Writing beyond the end of the buffer does nothing + cursor.seek(SeekFrom::End(0))?; + assert_eq!(cursor.write(b".")?, 0); + + // Flush is no-op + cursor.flush()?; + + Ok(()) +} diff --git a/tests/byte_string.rs b/tests/byte_string.rs index 48be2d9b..76e43e14 100644 --- a/tests/byte_string.rs +++ b/tests/byte_string.rs @@ -22,38 +22,38 @@ fn test_byte_string_round_trip() -> Result<()> { let globals = lua.globals(); - let isi = globals.get::<_, BString>("invalid_sequence_identifier")?; + let isi = globals.get::("invalid_sequence_identifier")?; assert_eq!(isi, [0xa0, 0xa1].as_ref()); - let i2os2 = globals.get::<_, BString>("invalid_2_octet_sequence_2nd")?; + let i2os2 = globals.get::("invalid_2_octet_sequence_2nd")?; assert_eq!(i2os2, [0xc3, 0x28].as_ref()); - let i3os2 = globals.get::<_, BString>("invalid_3_octet_sequence_2nd")?; + let i3os2 = globals.get::("invalid_3_octet_sequence_2nd")?; assert_eq!(i3os2, [0xe2, 0x28, 0xa1].as_ref()); - let i3os3 = globals.get::<_, BString>("invalid_3_octet_sequence_3rd")?; + let i3os3 = globals.get::("invalid_3_octet_sequence_3rd")?; assert_eq!(i3os3, [0xe2, 0x82, 0x28].as_ref()); - let i4os2 = globals.get::<_, BString>("invalid_4_octet_sequence_2nd")?; + let i4os2 = globals.get::("invalid_4_octet_sequence_2nd")?; assert_eq!(i4os2, [0xf0, 0x28, 0x8c, 0xbc].as_ref()); - let i4os3 = globals.get::<_, BString>("invalid_4_octet_sequence_3rd")?; + let i4os3 = globals.get::("invalid_4_octet_sequence_3rd")?; assert_eq!(i4os3, [0xf0, 0x90, 0x28, 0xbc].as_ref()); - let i4os4 = globals.get::<_, BString>("invalid_4_octet_sequence_4th")?; + let i4os4 = globals.get::("invalid_4_octet_sequence_4th")?; assert_eq!(i4os4, [0xf0, 0x28, 0x8c, 0x28].as_ref()); - let aas = globals.get::<_, BString>("an_actual_string")?; + let aas = globals.get::("an_actual_string")?; assert_eq!(aas, b"Hello, world!".as_ref()); - globals.set::<_, &BStr>("bstr_invalid_sequence_identifier", isi.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_2_octet_sequence_2nd", i2os2.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_3_octet_sequence_2nd", i3os2.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_3_octet_sequence_3rd", i3os3.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_4_octet_sequence_2nd", i4os2.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_4_octet_sequence_3rd", i4os3.as_ref())?; - globals.set::<_, &BStr>("bstr_invalid_4_octet_sequence_4th", i4os4.as_ref())?; - globals.set::<_, &BStr>("bstr_an_actual_string", aas.as_ref())?; + globals.set("bstr_invalid_sequence_identifier", isi.as_ref() as &BStr)?; + globals.set("bstr_invalid_2_octet_sequence_2nd", i2os2.as_ref() as &BStr)?; + globals.set("bstr_invalid_3_octet_sequence_2nd", i3os2.as_ref() as &BStr)?; + globals.set("bstr_invalid_3_octet_sequence_3rd", i3os3.as_ref() as &BStr)?; + globals.set("bstr_invalid_4_octet_sequence_2nd", i4os2.as_ref() as &BStr)?; + globals.set("bstr_invalid_4_octet_sequence_3rd", i4os3.as_ref() as &BStr)?; + globals.set("bstr_invalid_4_octet_sequence_4th", i4os4.as_ref() as &BStr)?; + globals.set("bstr_an_actual_string", aas.as_ref() as &BStr)?; lua.load( r#" @@ -69,14 +69,14 @@ fn test_byte_string_round_trip() -> Result<()> { ) .exec()?; - globals.set::<_, BString>("bstring_invalid_sequence_identifier", isi)?; - globals.set::<_, BString>("bstring_invalid_2_octet_sequence_2nd", i2os2)?; - globals.set::<_, BString>("bstring_invalid_3_octet_sequence_2nd", i3os2)?; - globals.set::<_, BString>("bstring_invalid_3_octet_sequence_3rd", i3os3)?; - globals.set::<_, BString>("bstring_invalid_4_octet_sequence_2nd", i4os2)?; - globals.set::<_, BString>("bstring_invalid_4_octet_sequence_3rd", i4os3)?; - globals.set::<_, BString>("bstring_invalid_4_octet_sequence_4th", i4os4)?; - globals.set::<_, BString>("bstring_an_actual_string", aas)?; + globals.set("bstring_invalid_sequence_identifier", isi)?; + globals.set("bstring_invalid_2_octet_sequence_2nd", i2os2)?; + globals.set("bstring_invalid_3_octet_sequence_2nd", i3os2)?; + globals.set("bstring_invalid_3_octet_sequence_3rd", i3os3)?; + globals.set("bstring_invalid_4_octet_sequence_2nd", i4os2)?; + globals.set("bstring_invalid_4_octet_sequence_3rd", i4os3)?; + globals.set("bstring_invalid_4_octet_sequence_4th", i4os4)?; + globals.set("bstring_an_actual_string", aas)?; lua.load( r#" diff --git a/tests/chunk.rs b/tests/chunk.rs index 4f87d603..95c54274 100644 --- a/tests/chunk.rs +++ b/tests/chunk.rs @@ -1,12 +1,37 @@ -use std::fs; -use std::io; +#[cfg(not(target_os = "wasi"))] +use std::{fs, io}; -use mlua::{Error, Lua, Result}; +use mlua::{Chunk, ChunkMode, Lua, Result}; #[test] +fn test_chunk_methods() -> Result<()> { + let lua = Lua::new(); + + #[cfg(unix)] + assert!(lua.load("return 123").name().starts_with("@tests/chunk.rs")); + let chunk2 = lua.load("return 123").set_name("@new_name"); + assert_eq!(chunk2.name(), "@new_name"); + + let env = lua.create_table_from([("a", 987)])?; + let chunk3 = lua.load("return a").set_environment(env.clone()); + assert_eq!(chunk3.environment().unwrap(), &env); + assert_eq!(chunk3.mode(), ChunkMode::Text); + assert_eq!(chunk3.call::(())?, 987); + + Ok(()) +} + +#[test] +#[cfg(not(target_os = "wasi"))] fn test_chunk_path() -> Result<()> { let lua = Lua::new(); + if cfg!(target_arch = "wasm32") { + // TODO: figure out why emscripten fails on file operations + // Also see https://github.com/rust-lang/rust/issues/119250 + return Ok(()); + } + let temp_dir = tempfile::tempdir().unwrap(); fs::write( temp_dir.path().join("module.lua"), @@ -14,15 +39,38 @@ fn test_chunk_path() -> Result<()> { return 321 "#, )?; - let i: i32 = lua.load(&temp_dir.path().join("module.lua")).eval()?; + let i: i32 = lua.load(temp_dir.path().join("module.lua")).eval()?; assert_eq!(i, 321); - match lua.load(&temp_dir.path().join("module2.lua")).exec() { - Err(Error::ExternalError(err)) - if err.downcast_ref::().unwrap().kind() == io::ErrorKind::NotFound => {} + match lua.load(&*temp_dir.path().join("module2.lua")).exec() { + Err(err) if err.downcast_ref::().unwrap().kind() == io::ErrorKind::NotFound => {} res => panic!("expected io::Error, got {:?}", res), }; + // &Path + assert_eq!( + (lua.load(&*temp_dir.path().join("module.lua").as_path())).eval::()?, + 321 + ); + + Ok(()) +} + +#[test] +fn test_chunk_impls() -> Result<()> { + let lua = Lua::new(); + + // StdString + assert_eq!(lua.load(String::from("1")).eval::()?, 1); + assert_eq!(lua.load(&String::from("2")).eval::()?, 2); + + // &[u8] + assert_eq!(lua.load(&b"3"[..]).eval::()?, 3); + + // Vec + assert_eq!(lua.load(b"4".to_vec()).eval::()?, 4); + assert_eq!(lua.load(&b"5".to_vec()).eval::()?, 5); + Ok(()) } @@ -37,18 +85,98 @@ fn test_chunk_macro() -> Result<()> { let data = lua.create_table()?; data.raw_set("num", 1)?; + let ud = mlua::AnyUserData::wrap("hello"); + let f = mlua::Function::wrap(|| Ok::<_, mlua::Error>(())); + lua.globals().set("g", 123)?; + let string = String::new(); + let str = string.as_str(); + lua.load(mlua::chunk! { assert($name == "Rustacean") + assert(type($table) == "table") assert($table[1] == 1) + assert(type($data) == "table") assert($data.num == 1) + assert(type($ud) == "userdata") + assert(type($f) == "function") + assert(type($str) == "string") + assert($str == "") assert(g == 123) s = 321 }) .exec()?; - assert_eq!(lua.globals().get::<_, i32>("s")?, 321); + assert_eq!(lua.globals().get::("s")?, 321); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_compiler() -> Result<()> { + let compiler = mlua::Compiler::new() + .set_optimization_level(2) + .set_debug_level(2) + .set_type_info_level(1) + .set_coverage_level(2) + .set_vector_ctor("vector.new") + .set_vector_type("vector") + .set_mutable_globals(["mutable_global"]) + .set_userdata_types(["MyUserdata"]) + .set_disabled_builtins(["tostring"]); + + assert!(compiler.compile("return tostring(vector.new(1, 2, 3))").is_ok()); + + // Error + match compiler.compile("%") { + Err(mlua::Error::SyntaxError { ref message, .. }) => { + assert!(message.contains("Expected identifier when parsing expression, got '%'"),); + } + res => panic!("expected result: {res:?}"), + } + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_compiler_library_constants() { + use mlua::{Compiler, Vector}; + + let compiler = Compiler::new() + .set_optimization_level(2) + .add_library_constant("mylib.const_bool", true) + .add_library_constant("mylib.const_num", 123.0) + .add_library_constant("mylib.const_vec", Vector::zero()) + .add_library_constant("mylib.const_str", "value1"); + + let lua = Lua::new(); + lua.set_compiler(compiler); + let const_bool = lua.load("return mylib.const_bool").eval::().unwrap(); + assert_eq!(const_bool, true); + let const_num = lua.load("return mylib.const_num").eval::().unwrap(); + assert_eq!(const_num, 123.0); + let const_vec = lua.load("return mylib.const_vec").eval::().unwrap(); + assert_eq!(const_vec, Vector::zero()); + let const_str = lua.load("return mylib.const_str").eval::(); + assert_eq!(const_str.unwrap(), "value1"); +} + +#[test] +fn test_chunk_wrap() -> Result<()> { + let lua = Lua::new(); + + let f = Chunk::wrap("return 123"); + lua.globals().set("f", f)?; + lua.load("assert(f() == 123)").exec().unwrap(); + + lua.globals().set("f2", Chunk::wrap("c()"))?; + assert!( + (lua.load("f2()").exec().err().unwrap().to_string()).contains(file!()), + "wrong chunk location" + ); Ok(()) } diff --git a/tests/compile.rs b/tests/compile.rs index e4d822a2..c8ee4511 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -7,15 +7,15 @@ fn test_compilation() { t.compile_fail("tests/compile/lua_norefunwindsafe.rs"); t.compile_fail("tests/compile/ref_nounwindsafe.rs"); t.compile_fail("tests/compile/scope_callback_capture.rs"); - t.compile_fail("tests/compile/scope_callback_inner.rs"); - t.compile_fail("tests/compile/scope_callback_outer.rs"); t.compile_fail("tests/compile/scope_invariance.rs"); t.compile_fail("tests/compile/scope_mutable_aliasing.rs"); t.compile_fail("tests/compile/scope_userdata_borrow.rs"); - t.compile_fail("tests/compile/static_callback_args.rs"); #[cfg(feature = "async")] - t.compile_fail("tests/compile/async_nonstatic_userdata.rs"); + { + t.compile_fail("tests/compile/async_any_userdata_method.rs"); + t.compile_fail("tests/compile/async_nonstatic_userdata.rs"); + } #[cfg(feature = "send")] t.compile_fail("tests/compile/non_send.rs"); diff --git a/tests/compile/async_any_userdata_method.rs b/tests/compile/async_any_userdata_method.rs new file mode 100644 index 00000000..680eaebd --- /dev/null +++ b/tests/compile/async_any_userdata_method.rs @@ -0,0 +1,15 @@ +use mlua::{Lua, UserDataMethods}; + +fn main() { + let lua = Lua::new(); + + lua.register_userdata_type::(|reg| { + let s = String::new(); + let mut s = &s; + reg.add_async_method("t", |_, this, ()| async { + s = &*this; + Ok(()) + }); + }) + .unwrap(); +} diff --git a/tests/compile/async_any_userdata_method.stderr b/tests/compile/async_any_userdata_method.stderr new file mode 100644 index 00000000..3e01c45b --- /dev/null +++ b/tests/compile/async_any_userdata_method.stderr @@ -0,0 +1,90 @@ +error[E0596]: cannot borrow `s` as mutable, as it is a captured variable in a `Fn` closure + --> tests/compile/async_any_userdata_method.rs:9:49 + | + 8 | let mut s = &s; + | ----- `s` declared here, outside the closure + 9 | reg.add_async_method("t", |_, this, ()| async { + | ------------- ^^^^^ cannot borrow as mutable + | | + | in this closure +10 | s = &*this; + | - mutable borrow occurs due to use of `s` in closure + +error[E0373]: async block may outlive the current function, but it borrows `this`, which is owned by the current function + --> tests/compile/async_any_userdata_method.rs:9:49 + | + 9 | reg.add_async_method("t", |_, this, ()| async { + | ^^^^^ may outlive borrowed value `this` +10 | s = &*this; + | ---- `this` is borrowed here + | +note: async block is returned here + --> tests/compile/async_any_userdata_method.rs:9:49 + | + 9 | reg.add_async_method("t", |_, this, ()| async { + | _________________________________________________^ +10 | | s = &*this; +11 | | Ok(()) +12 | | }); + | |_________^ +help: to force the async block to take ownership of `this` (and any other referenced variables), use the `move` keyword + | + 9 | reg.add_async_method("t", |_, this, ()| async move { + | ++++ + +error: lifetime may not live long enough + --> tests/compile/async_any_userdata_method.rs:9:49 + | + 9 | reg.add_async_method("t", |_, this, ()| async { + | ___________________________________-------------_^ + | | | | + | | | return type of closure `{async block@$DIR/tests/compile/async_any_userdata_method.rs:9:49: 9:54}` contains a lifetime `'2` + | | lifetime `'1` represents this closure's body +10 | | s = &*this; +11 | | Ok(()) +12 | | }); + | |_________^ returning this value requires that `'1` must outlive `'2` + | + = note: closure implements `Fn`, so references to captured variables can't escape the closure + +error[E0597]: `s` does not live long enough + --> tests/compile/async_any_userdata_method.rs:8:21 + | + 7 | let s = String::new(); + | - binding `s` declared here + 8 | let mut s = &s; + | ^^ borrowed value does not live long enough + 9 | / reg.add_async_method("t", |_, this, ()| async { +10 | | s = &*this; +11 | | Ok(()) +12 | | }); + | |__________- argument requires that `s` is borrowed for `'static` +13 | }) + | - `s` dropped here while still borrowed + | +note: requirement that the value outlives `'static` introduced here + --> src/userdata.rs + | + | M: Fn(Lua, UserDataRef, A) -> MR + MaybeSend + 'static, + | ^^^^^^^ + +error[E0373]: closure may outlive the current function, but it borrows `s`, which is owned by the current function + --> tests/compile/async_any_userdata_method.rs:9:35 + | + 9 | reg.add_async_method("t", |_, this, ()| async { + | ^^^^^^^^^^^^^ may outlive borrowed value `s` +10 | s = &*this; + | - `s` is borrowed here + | +note: function requires argument type to outlive `'static` + --> tests/compile/async_any_userdata_method.rs:9:9 + | + 9 | / reg.add_async_method("t", |_, this, ()| async { +10 | | s = &*this; +11 | | Ok(()) +12 | | }); + | |__________^ +help: to force the closure to take ownership of `s` (and any other referenced variables), use the `move` keyword + | + 9 | reg.add_async_method("t", move |_, this, ()| async { + | ++++ diff --git a/tests/compile/async_nonstatic_userdata.rs b/tests/compile/async_nonstatic_userdata.rs index 8aede321..d4a73eb9 100644 --- a/tests/compile/async_nonstatic_userdata.rs +++ b/tests/compile/async_nonstatic_userdata.rs @@ -1,13 +1,11 @@ -use mlua::{Lua, UserData, UserDataMethods}; +use mlua::{UserData, UserDataMethods}; fn main() { - let ref lua = Lua::new(); - #[derive(Clone)] struct MyUserData<'a>(&'a i64); - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + impl UserData for MyUserData<'_> { + fn add_methods>(methods: &mut M) { methods.add_async_method("print", |_, data, ()| async move { println!("{}", data.0); Ok(()) diff --git a/tests/compile/async_nonstatic_userdata.stderr b/tests/compile/async_nonstatic_userdata.stderr index c03a56a5..18412e4a 100644 --- a/tests/compile/async_nonstatic_userdata.stderr +++ b/tests/compile/async_nonstatic_userdata.stderr @@ -1,41 +1,11 @@ -error[E0495]: cannot infer an appropriate lifetime due to conflicting requirements - --> tests/compile/async_nonstatic_userdata.rs:11:72 - | -11 | methods.add_async_method("print", |_, data, ()| async move { - | ________________________________________________________________________^ -12 | | println!("{}", data.0); -13 | | Ok(()) -14 | | }); - | |_____________^ - | -note: first, the lifetime cannot outlive the lifetime `'a` as defined here... - --> tests/compile/async_nonstatic_userdata.rs:9:10 - | -9 | impl<'a> UserData for MyUserData<'a> { - | ^^ -note: ...so that the types are compatible - --> tests/compile/async_nonstatic_userdata.rs:11:72 - | -11 | methods.add_async_method("print", |_, data, ()| async move { - | ________________________________________________________________________^ -12 | | println!("{}", data.0); -13 | | Ok(()) -14 | | }); - | |_____________^ - = note: expected `(MyUserData<'_>,)` - found `(MyUserData<'a>,)` -note: but, the lifetime must be valid for the lifetime `'lua` as defined here... - --> tests/compile/async_nonstatic_userdata.rs:10:24 - | -10 | fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - | ^^^^ -note: ...so that the type `impl Future` will meet its required lifetime bounds... - --> tests/compile/async_nonstatic_userdata.rs:11:21 - | -11 | methods.add_async_method("print", |_, data, ()| async move { - | ^^^^^^^^^^^^^^^^ -note: ...that is required by this bound - --> src/userdata.rs - | - | MR: 'lua + Future>; - | ^^^^ +error: lifetime may not live long enough + --> tests/compile/async_nonstatic_userdata.rs:9:13 + | + 7 | impl UserData for MyUserData<'_> { + | -- lifetime `'1` appears in the `impl`'s self type + 8 | fn add_methods>(methods: &mut M) { + 9 | / methods.add_async_method("print", |_, data, ()| async move { +10 | | println!("{}", data.0); +11 | | Ok(()) +12 | | }); + | |______________^ requires that `'1` must outlive `'static` diff --git a/tests/compile/function_borrow.rs b/tests/compile/function_borrow.rs index f64f3b8f..2c532361 100644 --- a/tests/compile/function_borrow.rs +++ b/tests/compile/function_borrow.rs @@ -6,7 +6,5 @@ fn main() { let test = Test(0); let lua = Lua::new(); - let _ = lua.create_function(|_, ()| -> Result { - Ok(test.0) - }); + let _ = lua.create_function(|_, ()| -> Result { Ok(test.0) }); } diff --git a/tests/compile/function_borrow.stderr b/tests/compile/function_borrow.stderr index e99c2875..5bc66d65 100644 --- a/tests/compile/function_borrow.stderr +++ b/tests/compile/function_borrow.stderr @@ -1,20 +1,17 @@ -error[E0373]: closure may outlive the current function, but it borrows `test`, which is owned by the current function - --> tests/compile/function_borrow.rs:9:33 - | -9 | let _ = lua.create_function(|_, ()| -> Result { - | ^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `test` -10 | Ok(test.0) - | ------ `test` is borrowed here - | +error[E0373]: closure may outlive the current function, but it borrows `test.0`, which is owned by the current function + --> tests/compile/function_borrow.rs:9:33 + | +9 | let _ = lua.create_function(|_, ()| -> Result { Ok(test.0) }); + | ^^^^^^^^^^^^^^^^^^^^^^ ------ `test.0` is borrowed here + | | + | may outlive borrowed value `test.0` + | note: function requires argument type to outlive `'static` - --> tests/compile/function_borrow.rs:9:13 - | -9 | let _ = lua.create_function(|_, ()| -> Result { - | _____________^ -10 | | Ok(test.0) -11 | | }); - | |______^ -help: to force the closure to take ownership of `test` (and any other referenced variables), use the `move` keyword - | -9 | let _ = lua.create_function(move |_, ()| -> Result { - | ++++ + --> tests/compile/function_borrow.rs:9:13 + | +9 | let _ = lua.create_function(|_, ()| -> Result { Ok(test.0) }); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +help: to force the closure to take ownership of `test.0` (and any other referenced variables), use the `move` keyword + | +9 | let _ = lua.create_function(move |_, ()| -> Result { Ok(test.0) }); + | ++++ diff --git a/tests/compile/lua_norefunwindsafe.stderr b/tests/compile/lua_norefunwindsafe.stderr index a441ec87..4094cd2b 100644 --- a/tests/compile/lua_norefunwindsafe.stderr +++ b/tests/compile/lua_norefunwindsafe.stderr @@ -1,18 +1,101 @@ -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary - --> tests/compile/lua_norefunwindsafe.rs:7:5 - | -7 | catch_unwind(|| lua.create_table().unwrap()); - | ^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary - | - = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `PhantomData>>` - = note: required because it appears within the type `Arc>` - = note: required because it appears within the type `Lua` - = note: required because of the requirements on the impl of `UnwindSafe` for `&Lua` - = note: required because it appears within the type `[closure@$DIR/tests/compile/lua_norefunwindsafe.rs:7:18: 7:48]` -note: required by a bound in `catch_unwind` - --> $RUST/std/src/panic.rs - | - | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { - | ^^^^^^^^^^ required by this bound in `catch_unwind` +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + --> tests/compile/lua_norefunwindsafe.rs:7:18 + | +7 | catch_unwind(|| lua.create_table().unwrap()); + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + | | + | required by a bound introduced by this call + | + = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` +note: required because it appears within the type `lock_api::remutex::ReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct ReentrantMutex { + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `alloc::sync::ArcInner>` + --> $RUST/alloc/src/sync.rs + | + | struct ArcInner { + | ^^^^^^^^ +note: required because it appears within the type `PhantomData>>` + --> $RUST/core/src/marker.rs + | + | pub struct PhantomData; + | ^^^^^^^^^^^ +note: required because it appears within the type `Arc>` + --> $RUST/alloc/src/sync.rs + | + | pub struct Arc< + | ^^^ +note: required because it appears within the type `Lua` + --> src/state.rs + | + | pub struct Lua { + | ^^^ + = note: required for `&Lua` to implement `UnwindSafe` +note: required because it's used within this closure + --> tests/compile/lua_norefunwindsafe.rs:7:18 + | +7 | catch_unwind(|| lua.create_table().unwrap()); + | ^^ +note: required by a bound in `std::panic::catch_unwind` + --> $RUST/std/src/panic.rs + | + | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { + | ^^^^^^^^^^ required by this bound in `catch_unwind` + +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + --> tests/compile/lua_norefunwindsafe.rs:7:18 + | +7 | catch_unwind(|| lua.create_table().unwrap()); + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + | | + | required by a bound introduced by this call + | + = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` +note: required because it appears within the type `Cell` + --> $RUST/core/src/cell.rs + | + | pub struct Cell { + | ^^^^ +note: required because it appears within the type `lock_api::remutex::RawReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct RawReentrantMutex { + | ^^^^^^^^^^^^^^^^^ +note: required because it appears within the type `lock_api::remutex::ReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct ReentrantMutex { + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `alloc::sync::ArcInner>` + --> $RUST/alloc/src/sync.rs + | + | struct ArcInner { + | ^^^^^^^^ +note: required because it appears within the type `PhantomData>>` + --> $RUST/core/src/marker.rs + | + | pub struct PhantomData; + | ^^^^^^^^^^^ +note: required because it appears within the type `Arc>` + --> $RUST/alloc/src/sync.rs + | + | pub struct Arc< + | ^^^ +note: required because it appears within the type `Lua` + --> src/state.rs + | + | pub struct Lua { + | ^^^ + = note: required for `&Lua` to implement `UnwindSafe` +note: required because it's used within this closure + --> tests/compile/lua_norefunwindsafe.rs:7:18 + | +7 | catch_unwind(|| lua.create_table().unwrap()); + | ^^ +note: required by a bound in `std::panic::catch_unwind` + --> $RUST/std/src/panic.rs + | + | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { + | ^^^^^^^^^^ required by this bound in `catch_unwind` diff --git a/tests/compile/non_send.rs b/tests/compile/non_send.rs index fe030a5c..6f6b1e99 100644 --- a/tests/compile/non_send.rs +++ b/tests/compile/non_send.rs @@ -8,10 +8,8 @@ fn main() -> Result<()> { let data = Rc::new(Cell::new(0)); - lua.create_function(move |_, ()| { - Ok(data.get()) - })? - .call::<_, i32>(())?; + lua.create_function(move |_, ()| Ok(data.get()))? + .call::(())?; Ok(()) } diff --git a/tests/compile/non_send.stderr b/tests/compile/non_send.stderr index ecec24cc..c94b720f 100644 --- a/tests/compile/non_send.stderr +++ b/tests/compile/non_send.stderr @@ -1,19 +1,25 @@ error[E0277]: `Rc>` cannot be sent between threads safely - --> tests/compile/non_send.rs:11:9 - | -11 | lua.create_function(move |_, ()| { - | _________^^^^^^^^^^^^^^^_- - | | | - | | `Rc>` cannot be sent between threads safely -12 | | Ok(data.get()) -13 | | })? - | |_____- within this `[closure@$DIR/tests/compile/non_send.rs:11:25: 13:6]` - | - = help: within `[closure@$DIR/tests/compile/non_send.rs:11:25: 13:6]`, the trait `Send` is not implemented for `Rc>` - = note: required because it appears within the type `[closure@$DIR/tests/compile/non_send.rs:11:25: 13:6]` - = note: required because of the requirements on the impl of `mlua::types::MaybeSend` for `[closure@$DIR/tests/compile/non_send.rs:11:25: 13:6]` + --> tests/compile/non_send.rs:11:25 + | +11 | lua.create_function(move |_, ()| Ok(data.get()))? + | --------------- ------------^^^^^^^^^^^^^^^ + | | | + | | `Rc>` cannot be sent between threads safely + | | within this `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}` + | required by a bound introduced by this call + | + = help: within `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}`, the trait `Send` is not implemented for `Rc>` +note: required because it's used within this closure + --> tests/compile/non_send.rs:11:25 + | +11 | lua.create_function(move |_, ()| Ok(data.get()))? + | ^^^^^^^^^^^^ + = note: required for `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}` to implement `MaybeSend` note: required by a bound in `Lua::create_function` - --> src/lua.rs - | - | F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, - | ^^^^^^^^^ required by this bound in `Lua::create_function` + --> src/state.rs + | + | pub fn create_function(&self, func: F) -> Result + | --------------- required by a bound in this associated function + | where + | F: Fn(&Lua, A) -> Result + MaybeSend + 'static, + | ^^^^^^^^^ required by this bound in `Lua::create_function` diff --git a/tests/compile/ref_nounwindsafe.stderr b/tests/compile/ref_nounwindsafe.stderr index 8d3704d2..757083df 100644 --- a/tests/compile/ref_nounwindsafe.stderr +++ b/tests/compile/ref_nounwindsafe.stderr @@ -1,20 +1,111 @@ -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary - --> tests/compile/ref_nounwindsafe.rs:8:5 - | -8 | catch_unwind(move || table.set("a", "b").unwrap()); - | ^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary - | - = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` - = note: required because it appears within the type `alloc::sync::ArcInner>` - = note: required because it appears within the type `PhantomData>>` - = note: required because it appears within the type `Arc>` - = note: required because it appears within the type `Lua` - = note: required because of the requirements on the impl of `UnwindSafe` for `&Lua` - = note: required because it appears within the type `mlua::types::LuaRef<'_>` - = note: required because it appears within the type `LuaTable<'_>` - = note: required because it appears within the type `[closure@$DIR/tests/compile/ref_nounwindsafe.rs:8:18: 8:54]` -note: required by a bound in `catch_unwind` - --> $RUST/std/src/panic.rs - | - | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { - | ^^^^^^^^^^ required by this bound in `catch_unwind` +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + | | + | required by a bound introduced by this call + | + = help: within `alloc::sync::ArcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` +note: required because it appears within the type `lock_api::remutex::ReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct ReentrantMutex { + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `alloc::sync::ArcInner>` + --> $RUST/alloc/src/sync.rs + | + | struct ArcInner { + | ^^^^^^^^ + = note: required for `NonNull>>` to implement `UnwindSafe` +note: required because it appears within the type `std::sync::Weak>` + --> $RUST/alloc/src/sync.rs + | + | pub struct Weak< + | ^^^^ +note: required because it appears within the type `WeakLua` + --> src/state.rs + | + | pub struct WeakLua(XWeak>); + | ^^^^^^^ +note: required because it appears within the type `mlua::types::value_ref::ValueRef` + --> src/types/value_ref.rs + | + | pub struct ValueRef { + | ^^^^^^^^ +note: required because it appears within the type `LuaTable` + --> src/table.rs + | + | pub struct Table(pub(crate) ValueRef); + | ^^^^^ +note: required because it's used within this closure + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ^^^^^^^ +note: required by a bound in `std::panic::catch_unwind` + --> $RUST/std/src/panic.rs + | + | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { + | ^^^^^^^^^^ required by this bound in `catch_unwind` + +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferable across a catch_unwind boundary + | | + | required by a bound introduced by this call + | + = help: within `alloc::sync::ArcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` +note: required because it appears within the type `Cell` + --> $RUST/core/src/cell.rs + | + | pub struct Cell { + | ^^^^ +note: required because it appears within the type `lock_api::remutex::RawReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct RawReentrantMutex { + | ^^^^^^^^^^^^^^^^^ +note: required because it appears within the type `lock_api::remutex::ReentrantMutex` + --> $CARGO/lock_api-$VERSION/src/remutex.rs + | + | pub struct ReentrantMutex { + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `alloc::sync::ArcInner>` + --> $RUST/alloc/src/sync.rs + | + | struct ArcInner { + | ^^^^^^^^ + = note: required for `NonNull>>` to implement `UnwindSafe` +note: required because it appears within the type `std::sync::Weak>` + --> $RUST/alloc/src/sync.rs + | + | pub struct Weak< + | ^^^^ +note: required because it appears within the type `WeakLua` + --> src/state.rs + | + | pub struct WeakLua(XWeak>); + | ^^^^^^^ +note: required because it appears within the type `mlua::types::value_ref::ValueRef` + --> src/types/value_ref.rs + | + | pub struct ValueRef { + | ^^^^^^^^ +note: required because it appears within the type `LuaTable` + --> src/table.rs + | + | pub struct Table(pub(crate) ValueRef); + | ^^^^^ +note: required because it's used within this closure + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ^^^^^^^ +note: required by a bound in `std::panic::catch_unwind` + --> $RUST/std/src/panic.rs + | + | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { + | ^^^^^^^^^^ required by this bound in `catch_unwind` diff --git a/tests/compile/scope_callback_capture.rs b/tests/compile/scope_callback_capture.rs index 927c36d7..9f8ec53c 100644 --- a/tests/compile/scope_callback_capture.rs +++ b/tests/compile/scope_callback_capture.rs @@ -4,15 +4,11 @@ fn main() { let lua = Lua::new(); lua.scope(|scope| { let mut inner: Option
= None; - let f = scope - .create_function_mut(move |_, t: Table| { - if let Some(old) = inner.take() { - // Access old callback `Lua`. - } - inner = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; + let f = scope.create_function_mut(|_, t: Table| { + inner = Some(t); + Ok(()) + })?; + f.call::<()>(lua.create_table()?)?; Ok(()) }); } diff --git a/tests/compile/scope_callback_capture.stderr b/tests/compile/scope_callback_capture.stderr index 5c8914b3..7fa3a5e2 100644 --- a/tests/compile/scope_callback_capture.stderr +++ b/tests/compile/scope_callback_capture.stderr @@ -1,26 +1,24 @@ -warning: unused variable: `old` - --> $DIR/scope_callback_capture.rs:9:29 - | -9 | if let Some(old) = inner.take() { - | ^^^ help: if this is intentional, prefix it with an underscore: `_old` - | - = note: `#[warn(unused_variables)]` on by default - -error[E0521]: borrowed data escapes outside of closure - --> $DIR/scope_callback_capture.rs:7:17 +error[E0373]: closure may outlive the current function, but it borrows `inner`, which is owned by the current function + --> tests/compile/scope_callback_capture.rs:7:43 | -5 | lua.scope(|scope| { - | ----- - | | - | `scope` declared here, outside of the closure body - | `scope` is a reference that is only valid in the closure body -6 | let mut inner: Option
= None; -7 | let f = scope + 5 | lua.scope(|scope| { + | ----- has type `&'1 mlua::Scope<'1, '_>` + 6 | let mut inner: Option
= None; + 7 | let f = scope.create_function_mut(|_, t: Table| { + | ^^^^^^^^^^^^^ may outlive borrowed value `inner` + 8 | inner = Some(t); + | ----- `inner` is borrowed here + | +note: function requires argument type to outlive `'1` + --> tests/compile/scope_callback_capture.rs:7:17 + | + 7 | let f = scope.create_function_mut(|_, t: Table| { | _________________^ -8 | | .create_function_mut(move |_, t: Table| { -9 | | if let Some(old) = inner.take() { -10 | | // Access old callback `Lua`. -... | -13 | | Ok(()) -14 | | })?; - | |______________^ `scope` escapes the closure body here + 8 | | inner = Some(t); + 9 | | Ok(()) +10 | | })?; + | |__________^ +help: to force the closure to take ownership of `inner` (and any other referenced variables), use the `move` keyword + | + 7 | let f = scope.create_function_mut(move |_, t: Table| { + | ++++ diff --git a/tests/compile/scope_callback_inner.rs b/tests/compile/scope_callback_inner.rs deleted file mode 100644 index 037c6acb..00000000 --- a/tests/compile/scope_callback_inner.rs +++ /dev/null @@ -1,15 +0,0 @@ -use mlua::{Lua, Table}; - -fn main() { - let lua = Lua::new(); - lua.scope(|scope| { - let mut inner: Option
= None; - let f = scope - .create_function_mut(|_, t: Table| { - inner = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; - Ok(()) - }); -} diff --git a/tests/compile/scope_callback_inner.stderr b/tests/compile/scope_callback_inner.stderr deleted file mode 100644 index 55c38b94..00000000 --- a/tests/compile/scope_callback_inner.stderr +++ /dev/null @@ -1,42 +0,0 @@ -error[E0521]: borrowed data escapes outside of closure - --> tests/compile/scope_callback_inner.rs:7:17 - | -5 | lua.scope(|scope| { - | ----- - | | - | `scope` declared here, outside of the closure body - | `scope` is a reference that is only valid in the closure body -6 | let mut inner: Option
= None; -7 | let f = scope - | _________________^ -8 | | .create_function_mut(|_, t: Table| { -9 | | inner = Some(t); -10 | | Ok(()) -11 | | })?; - | |______________^ `scope` escapes the closure body here - -error[E0373]: closure may outlive the current function, but it borrows `inner`, which is owned by the current function - --> tests/compile/scope_callback_inner.rs:8:34 - | -5 | lua.scope(|scope| { - | ----- has type `&mlua::Scope<'_, '2>` -... -8 | .create_function_mut(|_, t: Table| { - | ^^^^^^^^^^^^^ may outlive borrowed value `inner` -9 | inner = Some(t); - | ----- `inner` is borrowed here - | -note: function requires argument type to outlive `'2` - --> tests/compile/scope_callback_inner.rs:7:17 - | -7 | let f = scope - | _________________^ -8 | | .create_function_mut(|_, t: Table| { -9 | | inner = Some(t); -10 | | Ok(()) -11 | | })?; - | |______________^ -help: to force the closure to take ownership of `inner` (and any other referenced variables), use the `move` keyword - | -8 | .create_function_mut(move |_, t: Table| { - | ++++ diff --git a/tests/compile/scope_callback_outer.rs b/tests/compile/scope_callback_outer.rs deleted file mode 100644 index 7c9974ea..00000000 --- a/tests/compile/scope_callback_outer.rs +++ /dev/null @@ -1,15 +0,0 @@ -use mlua::{Lua, Table}; - -fn main() { - let lua = Lua::new(); - let mut outer: Option
= None; - lua.scope(|scope| { - let f = scope - .create_function_mut(|_, t: Table| { - outer = Some(t); - Ok(()) - })?; - f.call::<_, ()>(lua.create_table()?)?; - Ok(()) - }); -} diff --git a/tests/compile/scope_callback_outer.stderr b/tests/compile/scope_callback_outer.stderr deleted file mode 100644 index 2a15e4a6..00000000 --- a/tests/compile/scope_callback_outer.stderr +++ /dev/null @@ -1,30 +0,0 @@ -error[E0521]: borrowed data escapes outside of closure - --> $DIR/scope_callback_outer.rs:7:17 - | -6 | lua.scope(|scope| { - | ----- - | | - | `scope` declared here, outside of the closure body - | `scope` is a reference that is only valid in the closure body -7 | let f = scope - | _________________^ -8 | | .create_function_mut(|_, t: Table| { -9 | | outer = Some(t); -10 | | Ok(()) -11 | | })?; - | |______________^ `scope` escapes the closure body here - -error[E0597]: `outer` does not live long enough - --> $DIR/scope_callback_outer.rs:9:17 - | -6 | lua.scope(|scope| { - | ------- value captured here -... -9 | outer = Some(t); - | ^^^^^ borrowed value does not live long enough -... -15 | } - | - - | | - | `outer` dropped here while still borrowed - | borrow might be used here, when `outer` is dropped and runs the destructor for type `Option>` diff --git a/tests/compile/scope_invariance.rs b/tests/compile/scope_invariance.rs index e4f4ea75..0414efc8 100644 --- a/tests/compile/scope_invariance.rs +++ b/tests/compile/scope_invariance.rs @@ -10,14 +10,13 @@ fn main() { let f = { let mut test = Test { field: 0 }; - scope - .create_function_mut(|_, ()| { - test.field = 42; - //~^ error: `test` does not live long enough - Ok(()) - })? + scope.create_function_mut(|_, ()| { + test.field = 42; + //~^ error: `test` does not live long enough + Ok(()) + })? }; - f.call::<_, ()>(()) + f.call::<()>(()) }); } diff --git a/tests/compile/scope_invariance.stderr b/tests/compile/scope_invariance.stderr index 9c91f6a6..8bad0c12 100644 --- a/tests/compile/scope_invariance.stderr +++ b/tests/compile/scope_invariance.stderr @@ -1,25 +1,24 @@ -error[E0373]: closure may outlive the current function, but it borrows `test`, which is owned by the current function - --> tests/compile/scope_invariance.rs:14:38 +error[E0373]: closure may outlive the current function, but it borrows `test.field`, which is owned by the current function + --> tests/compile/scope_invariance.rs:13:39 | -9 | lua.scope(|scope| { - | ----- has type `&mlua::Scope<'_, '1>` + 9 | lua.scope(|scope| { + | ----- has type `&'1 mlua::Scope<'1, '_>` ... -14 | .create_function_mut(|_, ()| { - | ^^^^^^^ may outlive borrowed value `test` -15 | test.field = 42; - | ---------- `test` is borrowed here +13 | scope.create_function_mut(|_, ()| { + | ^^^^^^^ may outlive borrowed value `test.field` +14 | test.field = 42; + | ---------- `test.field` is borrowed here | note: function requires argument type to outlive `'1` --> tests/compile/scope_invariance.rs:13:13 | -13 | / scope -14 | | .create_function_mut(|_, ()| { -15 | | test.field = 42; -16 | | //~^ error: `test` does not live long enough -17 | | Ok(()) -18 | | })? - | |__________________^ -help: to force the closure to take ownership of `test` (and any other referenced variables), use the `move` keyword +13 | / scope.create_function_mut(|_, ()| { +14 | | test.field = 42; +15 | | //~^ error: `test` does not live long enough +16 | | Ok(()) +17 | | })? + | |______________^ +help: to force the closure to take ownership of `test.field` (and any other referenced variables), use the `move` keyword | -14 | .create_function_mut(move |_, ()| { - | ++++ +13 | scope.create_function_mut(move |_, ()| { + | ++++ diff --git a/tests/compile/scope_mutable_aliasing.rs b/tests/compile/scope_mutable_aliasing.rs index 4e1dcf9e..4745296b 100644 --- a/tests/compile/scope_mutable_aliasing.rs +++ b/tests/compile/scope_mutable_aliasing.rs @@ -2,14 +2,14 @@ use mlua::{Lua, UserData}; fn main() { struct MyUserData<'a>(&'a mut i32); - impl<'a> UserData for MyUserData<'a> {} + impl UserData for MyUserData<'_> {} let mut i = 1; let lua = Lua::new(); lua.scope(|scope| { - let _a = scope.create_nonstatic_userdata(MyUserData(&mut i)).unwrap(); - let _b = scope.create_nonstatic_userdata(MyUserData(&mut i)).unwrap(); + let _a = scope.create_userdata(MyUserData(&mut i)).unwrap(); + let _b = scope.create_userdata(MyUserData(&mut i)).unwrap(); Ok(()) }); } diff --git a/tests/compile/scope_mutable_aliasing.stderr b/tests/compile/scope_mutable_aliasing.stderr index c661826f..d7724660 100644 --- a/tests/compile/scope_mutable_aliasing.stderr +++ b/tests/compile/scope_mutable_aliasing.stderr @@ -1,9 +1,18 @@ error[E0499]: cannot borrow `i` as mutable more than once at a time - --> $DIR/scope_mutable_aliasing.rs:12:61 + --> tests/compile/scope_mutable_aliasing.rs:12:51 | -11 | let _a = scope.create_nonstatic_userdata(MyUserData(&mut i)).unwrap(); - | ------ first mutable borrow occurs here -12 | let _b = scope.create_nonstatic_userdata(MyUserData(&mut i)).unwrap(); - | ------------------------- ^^^^^^ second mutable borrow occurs here - | | - | first borrow later used by call +10 | lua.scope(|scope| { + | ----- has type `&mlua::Scope<'_, '1>` +11 | let _a = scope.create_userdata(MyUserData(&mut i)).unwrap(); + | ----------------------------------------- + | | | + | | first mutable borrow occurs here + | argument requires that `i` is borrowed for `'1` +12 | let _b = scope.create_userdata(MyUserData(&mut i)).unwrap(); + | ^^^^^^ second mutable borrow occurs here + | +note: requirement that the value outlives `'1` introduced here + --> src/scope.rs + | + | T: UserData + 'env, + | ^^^^ diff --git a/tests/compile/scope_userdata_borrow.rs b/tests/compile/scope_userdata_borrow.rs index 24652344..53b1b813 100644 --- a/tests/compile/scope_userdata_borrow.rs +++ b/tests/compile/scope_userdata_borrow.rs @@ -3,16 +3,16 @@ use mlua::{Lua, UserData}; fn main() { // Should not allow userdata borrow to outlive lifetime of AnyUserData handle struct MyUserData<'a>(&'a i32); - impl<'a> UserData for MyUserData<'a> {} + impl UserData for MyUserData<'_> {} let igood = 1; let lua = Lua::new(); lua.scope(|scope| { - let _ugood = scope.create_nonstatic_userdata(MyUserData(&igood)).unwrap(); + let _ugood = scope.create_userdata(MyUserData(&igood)).unwrap(); let _ubad = { let ibad = 42; - scope.create_nonstatic_userdata(MyUserData(&ibad)).unwrap(); + scope.create_userdata(MyUserData(&ibad)).unwrap(); }; Ok(()) }); diff --git a/tests/compile/scope_userdata_borrow.stderr b/tests/compile/scope_userdata_borrow.stderr index 9d898049..7aa771f2 100644 --- a/tests/compile/scope_userdata_borrow.stderr +++ b/tests/compile/scope_userdata_borrow.stderr @@ -1,13 +1,21 @@ error[E0597]: `ibad` does not live long enough - --> tests/compile/scope_userdata_borrow.rs:15:56 + --> tests/compile/scope_userdata_borrow.rs:15:46 | 11 | lua.scope(|scope| { | ----- has type `&mlua::Scope<'_, '1>` ... -15 | scope.create_nonstatic_userdata(MyUserData(&ibad)).unwrap(); - | -------------------------------------------^^^^^-- - | | | - | | borrowed value does not live long enough +14 | let ibad = 42; + | ---- binding `ibad` declared here +15 | scope.create_userdata(MyUserData(&ibad)).unwrap(); + | ---------------------------------^^^^^-- + | | | + | | borrowed value does not live long enough | argument requires that `ibad` is borrowed for `'1` 16 | }; | - `ibad` dropped here while still borrowed + | +note: requirement that the value outlives `'1` introduced here + --> src/scope.rs + | + | T: UserData + 'env, + | ^^^^ diff --git a/tests/compile/static_callback_args.rs b/tests/compile/static_callback_args.rs deleted file mode 100644 index 66dbf8b2..00000000 --- a/tests/compile/static_callback_args.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::cell::RefCell; - -use mlua::{Lua, Result, Table}; - -fn main() -> Result<()> { - thread_local! { - static BAD_TIME: RefCell>> = RefCell::new(None); - } - - let lua = Lua::new(); - - lua.create_function(|_, table: Table| { - BAD_TIME.with(|bt| { - *bt.borrow_mut() = Some(table); - }); - Ok(()) - })? - .call::<_, ()>(lua.create_table()?)?; - - // In debug, this will panic with a reference leak before getting to the next part but - // it segfaults anyway. - drop(lua); - - BAD_TIME.with(|bt| { - println!( - "you're gonna have a bad time: {}", - bt.borrow().as_ref().unwrap().len().unwrap() - ); - }); - - Ok(()) -} diff --git a/tests/compile/static_callback_args.stderr b/tests/compile/static_callback_args.stderr deleted file mode 100644 index 0683dee9..00000000 --- a/tests/compile/static_callback_args.stderr +++ /dev/null @@ -1,33 +0,0 @@ -error[E0597]: `lua` does not live long enough - --> tests/compile/static_callback_args.rs:12:5 - | -12 | / lua.create_function(|_, table: Table| { -13 | | BAD_TIME.with(|bt| { -14 | | *bt.borrow_mut() = Some(table); -15 | | }); -16 | | Ok(()) -17 | | })? - | | ^ - | | | - | |______borrowed value does not live long enough - | argument requires that `lua` is borrowed for `'static` -... -32 | } - | - `lua` dropped here while still borrowed - -error[E0505]: cannot move out of `lua` because it is borrowed - --> tests/compile/static_callback_args.rs:22:10 - | -12 | / lua.create_function(|_, table: Table| { -13 | | BAD_TIME.with(|bt| { -14 | | *bt.borrow_mut() = Some(table); -15 | | }); -16 | | Ok(()) -17 | | })? - | | - - | | | - | |______borrow of `lua` occurs here - | argument requires that `lua` is borrowed for `'static` -... -22 | drop(lua); - | ^^^ move out of `lua` occurs here diff --git a/tests/compile/userdata_borrow.rs b/tests/compile/userdata_borrow.rs deleted file mode 100644 index 26eb3c77..00000000 --- a/tests/compile/userdata_borrow.rs +++ /dev/null @@ -1,19 +0,0 @@ -use mlua::{AnyUserData, Lua, Table, UserData, Result}; - -fn main() -> Result<()> { - let lua = Lua::new(); - let globals = lua.globals(); - - // Should not allow userdata borrow to outlive lifetime of AnyUserData handle - struct MyUserData; - impl UserData for MyUserData {}; - let _userdata_ref; - { - let touter = globals.get::<_, Table>("touter")?; - touter.set("userdata", lua.create_userdata(MyUserData)?)?; - let userdata = touter.get::<_, AnyUserData>("userdata")?; - _userdata_ref = userdata.borrow::(); - //~^ error: `userdata` does not live long enough - } - Ok(()) -} diff --git a/tests/compile/userdata_borrow.stderr b/tests/compile/userdata_borrow.stderr deleted file mode 100644 index 7ac96703..00000000 --- a/tests/compile/userdata_borrow.stderr +++ /dev/null @@ -1,13 +0,0 @@ -error[E0597]: `userdata` does not live long enough - --> $DIR/userdata_borrow.rs:15:25 - | -15 | _userdata_ref = userdata.borrow::(); - | ^^^^^^^^ borrowed value does not live long enough -16 | //~^ error: `userdata` does not live long enough -17 | } - | - `userdata` dropped here while still borrowed -18 | Ok(()) -19 | } - | - borrow might be used here, when `_userdata_ref` is dropped and runs the destructor for type `std::result::Result, mlua::error::Error>` - | - = note: values in a scope are dropped in the opposite order they are defined diff --git a/tests/conversion.rs b/tests/conversion.rs index 6d17d0ad..ca16327e 100644 --- a/tests/conversion.rs +++ b/tests/conversion.rs @@ -1,9 +1,379 @@ use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use std::ffi::{CStr, CString}; +use std::ffi::{CString, OsString}; +use std::path::PathBuf; +use bstr::BString; use maplit::{btreemap, btreeset, hashmap, hashset}; -use mlua::{Error, Lua, Result}; +use mlua::{ + AnyUserData, BorrowedBytes, BorrowedStr, Either, Error, Function, IntoLua, Lua, RegistryKey, Result, + Table, Thread, UserDataRef, Value, +}; + +#[test] +fn test_value_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let v = Value::Boolean(true); + let v2 = (&v).into_lua(&lua)?; + assert_eq!(v, v2); + + // Push into stack + let table = lua.create_table()?; + table.set("v", &v)?; + assert_eq!(v, table.get::("v")?); + + Ok(()) +} + +#[test] +fn test_string_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let s2 = (&s).into_lua(&lua)?; + assert_eq!(s, *s2.as_string().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("s", &s)?; + assert_eq!(s, table.get::("s")?); + + Ok(()) +} + +#[test] +fn test_string_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, s: mlua::LuaString| Ok(s))?; + let s = f.call::("hello, world!")?; + assert_eq!(s, "hello, world!"); + + // Should fallback to default conversion + let s = f.call::(42)?; + assert_eq!(s, "42"); + + Ok(()) +} + +#[test] +fn test_borrowedstr_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let bs = s.to_str()?; + let bs2 = (&bs).into_lua(&lua)?; + assert_eq!(bs2.as_string().unwrap(), "hello, world!"); + + // Push into stack + let table = lua.create_table()?; + table.set("bs", &bs)?; + assert_eq!(bs, table.get::("bs")?); + + Ok(()) +} + +#[test] +fn test_borrowedstr_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, s: BorrowedStr| Ok(s))?; + let s = f.call::("hello, world!")?; + assert_eq!(s, "hello, world!"); + + Ok(()) +} + +#[test] +fn test_borrowedbytes_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let bb = s.as_bytes(); + let bb2 = (&bb).into_lua(&lua)?; + assert_eq!(bb2.as_string().unwrap(), "hello, world!"); + + // Push into stack + let table = lua.create_table()?; + table.set("bb", &bb)?; + assert_eq!(bb, table.get::("bb")?.as_bytes()); + + Ok(()) +} + +#[test] +fn test_borrowedbytes_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, s: BorrowedBytes| Ok(s))?; + let s = f.call::("hello, world!")?; + assert_eq!(s, "hello, world!"); + + Ok(()) +} + +#[test] +fn test_table_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let t = lua.create_table()?; + let t2 = (&t).into_lua(&lua)?; + assert_eq!(&t, t2.as_table().unwrap()); + + // Push into stack + let f = lua.create_function(|_, (t, s): (Table, String)| t.set("s", s))?; + f.call::<()>((&t, "hello"))?; + assert_eq!("hello", t.get::("s")?); + + Ok(()) +} + +#[test] +fn test_function_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?; + let f2 = (&f).into_lua(&lua)?; + assert_eq!(&f, f2.as_function().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("f", &f)?; + assert_eq!(f, table.get::("f")?); + + Ok(()) +} + +#[test] +fn test_function_from_lua() -> Result<()> { + let lua = Lua::new(); + + assert!(lua.globals().get::("print").is_ok()); + match lua.globals().get::("math") { + Err(err @ Error::FromLuaConversionError { .. }) => { + assert_eq!(err.to_string(), "error converting Lua table to function"); + } + _ => panic!("expected `Error::FromLuaConversionError`"), + } + + Ok(()) +} + +#[test] +fn test_thread_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let f = lua.create_function(|_, ()| Ok::<_, Error>(()))?; + let th = lua.create_thread(f)?; + let th2 = (&th).into_lua(&lua)?; + assert_eq!(&th, th2.as_thread().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("th", &th)?; + assert_eq!(th, table.get::("th")?); + + Ok(()) +} + +#[test] +fn test_thread_from_lua() -> Result<()> { + let lua = Lua::new(); + + match lua.globals().get::("print") { + Err(err @ Error::FromLuaConversionError { .. }) => { + assert_eq!(err.to_string(), "error converting Lua function to thread"); + } + _ => panic!("expected `Error::FromLuaConversionError`"), + } + + Ok(()) +} + +#[test] +fn test_anyuserdata_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let ud = lua.create_any_userdata(String::from("hello"))?; + let ud2 = (&ud).into_lua(&lua)?; + assert_eq!(&ud, ud2.as_userdata().unwrap()); + + // Push into stack + let table = lua.create_table()?; + table.set("ud", &ud)?; + assert_eq!(ud, table.get::("ud")?); + assert_eq!("hello", *table.get::>("ud")?); + + Ok(()) +} + +#[test] +fn test_anyuserdata_from_lua() -> Result<()> { + let lua = Lua::new(); + + match lua.globals().get::("print") { + Err(err @ Error::FromLuaConversionError { .. }) => { + assert_eq!(err.to_string(), "error converting Lua function to userdata"); + } + _ => panic!("expected `Error::FromLuaConversionError`"), + } + + Ok(()) +} + +#[test] +fn test_error_conversion() -> Result<()> { + let lua = Lua::new(); + + // Any Lua value can be converted to `Error` + match lua.convert::(Error::external("external error")) { + Ok(Error::ExternalError(msg)) => assert_eq!(msg.to_string(), "external error"), + res => panic!("expected `Error::ExternalError`, got {res:?}"), + } + match lua.convert::("abc") { + Ok(Error::RuntimeError(msg)) => assert_eq!(msg, "abc"), + res => panic!("expected `Error::RuntimeError`, got {res:?}"), + } + match lua.convert::(true) { + Ok(Error::RuntimeError(msg)) => assert_eq!(msg, "true"), + res => panic!("expected `Error::RuntimeError`, got {res:?}"), + } + match lua.convert::(lua.globals()) { + Ok(Error::RuntimeError(msg)) => assert!(msg.starts_with("table:")), + res => panic!("expected `Error::RuntimeError`, got {res:?}"), + } + + Ok(()) +} + +#[test] +fn test_registry_value_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world")?; + let r = lua.create_registry_value(&s)?; + let value1 = lua.pack(&r)?; + let value2 = lua.pack(r)?; + assert_eq!(value1.to_string()?, "hello, world"); + assert_eq!(value1.to_pointer(), value2.to_pointer()); + + // Push into stack + let t = lua.create_table()?; + let r = lua.create_registry_value(&t)?; + let f = lua.create_function(|_, (t, k, v): (Table, Value, Value)| t.set(k, v))?; + f.call::<()>((&r, "hello", "world"))?; + f.call::<()>((r, "welcome", "to the jungle"))?; + assert_eq!(t.get::("hello")?, "world"); + assert_eq!(t.get::("welcome")?, "to the jungle"); + + // Try to set nil registry key + let r_nil = lua.create_registry_value(Value::Nil)?; + t.set("hello", &r_nil)?; + assert_eq!(t.get::("hello")?, Value::Nil); + + // Check non-owned registry key + let lua2 = Lua::new(); + let r2 = lua2.create_registry_value("abc")?; + assert!(matches!(f.call::<()>(&r2), Err(Error::MismatchedRegistryKey))); + + Ok(()) +} + +#[test] +fn test_registry_key_from_lua() -> Result<()> { + let lua = Lua::new(); + + let fkey = lua.load("function() return 1 end").eval::()?; + let f = lua.registry_value::(&fkey)?; + assert_eq!(f.call::(())?, 1); + + Ok(()) +} + +#[test] +fn test_bool_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + assert!(true.into_lua(&lua)?.is_boolean()); + + // Push into stack + let table = lua.create_table()?; + table.set("b", true)?; + assert_eq!(true, table.get::("b")?); + + Ok(()) +} + +#[test] +fn test_bool_from_lua() -> Result<()> { + let lua = Lua::new(); + + assert!(lua.globals().get::("print")?); + assert!(lua.convert::(123)?); + assert!(!lua.convert::(Value::Nil)?); + + Ok(()) +} + +#[test] +fn test_integer_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, i: i32| Ok(i))?; + assert_eq!(f.call::(42)?, 42); + + // Out of range + match f.call::(i64::MAX).err() { + Some(Error::CallbackError { cause, .. }) => match cause.as_ref() { + Error::BadArgument { cause, .. } => match cause.as_ref() { + Error::FromLuaConversionError { message, .. } => { + assert_eq!(message.as_ref().unwrap(), "out of range"); + } + err => panic!("expected Error::FromLuaConversionError, got {err:?}"), + }, + err => panic!("expected Error::BadArgument, got {err:?}"), + }, + err => panic!("expected Error::CallbackError, got {err:?}"), + } + + // Should fallback to default conversion + assert_eq!(f.call::("42")?, 42); + + Ok(()) +} + +#[test] +fn test_float_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, f: f32| Ok(f))?; + assert_eq!(f.call::(42.0)?, 42.0); + + // Out of range (but never fails) + let val = f.call::(f64::MAX)?; + assert!(val.is_infinite()); + + // Should fallback to default conversion + assert_eq!(f.call::("42.0")?, 42.0); + + Ok(()) +} #[test] fn test_conv_vec() -> Result<()> { @@ -80,8 +450,8 @@ fn test_conv_cstring() -> Result<()> { let s2: CString = lua.globals().get("s")?; assert_eq!(s, s2); - let cs = CStr::from_bytes_with_nul(b"hello\0").unwrap(); - lua.globals().set("cs", cs)?; + let cs = c"hello"; + lua.globals().set("cs", c"hello")?; let cs2: CString = lua.globals().get("cs")?; assert_eq!(cs, cs2.as_c_str()); @@ -133,8 +503,252 @@ fn test_conv_array() -> Result<()> { let v2: [i32; 3] = lua.globals().get("v")?; assert_eq!(v, v2); - let v2 = lua.globals().get::<_, [i32; 4]>("v"); + let v2 = lua.globals().get::<[i32; 4]>("v"); assert!(matches!(v2, Err(Error::FromLuaConversionError { .. }))); Ok(()) } + +#[test] +fn test_bstring_from_lua() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_string("hello, world")?; + let bstr = lua.unpack::(Value::String(s))?; + assert_eq!(bstr, "hello, world"); + + let bstr = lua.unpack::(Value::Integer(123))?; + assert_eq!(bstr, "123"); + + let bstr = lua.unpack::(Value::Number(-123.55))?; + assert_eq!(bstr, "-123.55"); + + // Test from stack + let f = lua.create_function(|_, bstr: BString| Ok(bstr))?; + let bstr = f.call::("hello, world")?; + assert_eq!(bstr, "hello, world"); + + let bstr = f.call::(-43.22)?; + assert_eq!(bstr, "-43.22"); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_bstring_from_lua_buffer() -> Result<()> { + let lua = Lua::new(); + + let buf = lua.create_buffer("hello, world")?; + let bstr = lua.convert::(buf)?; + assert_eq!(bstr, "hello, world"); + + // Test from stack + let f = lua.create_function(|_, bstr: BString| Ok(bstr))?; + let buf = lua.create_buffer("hello, world")?; + let bstr = f.call::(buf)?; + assert_eq!(bstr, "hello, world"); + + Ok(()) +} + +#[test] +fn test_osstring_into_from_lua() -> Result<()> { + let lua = Lua::new(); + + let s = OsString::from("hello, world"); + + let v = lua.pack(s.as_os_str())?; + assert!(v.is_string()); + assert_eq!(v.as_string().unwrap(), "hello, world"); + + let v = lua.pack(s)?; + assert!(v.is_string()); + assert_eq!(v.as_string().unwrap(), "hello, world"); + + let s = lua.create_string("hello, world")?; + let bstr = lua.unpack::(Value::String(s))?; + assert_eq!(bstr, "hello, world"); + + let bstr = lua.unpack::(Value::Integer(123))?; + assert_eq!(bstr, "123"); + + let bstr = lua.unpack::(Value::Number(-123.55))?; + assert_eq!(bstr, "-123.55"); + + Ok(()) +} + +#[test] +fn test_pathbuf_into_from_lua() -> Result<()> { + let lua = Lua::new(); + + let pb = PathBuf::from(env!("CARGO_TARGET_TMPDIR")); + let pb_str = pb.to_str().unwrap(); + + let v = lua.pack(pb.as_path())?; + assert!(v.is_string()); + assert_eq!(v.to_string().unwrap(), pb_str); + + let v = lua.pack(pb.clone())?; + assert!(v.is_string()); + assert_eq!(v.to_string().unwrap(), pb_str); + + let s = lua.create_string(pb_str)?; + let bstr = lua.unpack::(Value::String(s))?; + assert_eq!(bstr, pb); + + Ok(()) +} + +#[test] +fn test_option_into_from_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let v = Some(42); + let v2 = v.into_lua(&lua)?; + assert_eq!(v, v2.as_i32()); + + // Push into stack / get from stack + let f = lua.create_function(|_, v: Option| Ok(v))?; + assert_eq!(f.call::>(Some(42))?, Some(42)); + assert_eq!(f.call::>(Option::::None)?, None); + assert_eq!(f.call::>(())?, None); + + Ok(()) +} + +#[test] +fn test_either_enum() -> Result<()> { + // Left + let mut either = Either::<_, String>::Left(42); + assert!(either.is_left()); + assert_eq!(*either.as_ref().left().unwrap(), 42); + *either.as_mut().left().unwrap() = 44; + assert_eq!(*either.as_ref().left().unwrap(), 44); + assert_eq!(format!("{either}"), "44"); + assert_eq!(either.right(), None); + + // Right + either = Either::Right("hello".to_string()); + assert!(either.is_right()); + assert_eq!(*either.as_ref().right().unwrap(), "hello"); + *either.as_mut().right().unwrap() = "world".to_string(); + assert_eq!(*either.as_ref().right().unwrap(), "world"); + assert_eq!(format!("{either}"), "world"); + assert_eq!(either.left(), None); + + Ok(()) +} + +#[test] +fn test_either_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let mut either = Either::::Left(42); + assert_eq!(either.into_lua(&lua)?, Value::Integer(42)); + let t = lua.create_table()?; + either = Either::Right(&t); + assert!(matches!(either.into_lua(&lua)?, Value::Table(_))); + + // Push into stack + let f = + lua.create_function(|_, either: Either| either.right().unwrap().set("hello", "world"))?; + let t = lua.create_table()?; + either = Either::Right(&t); + f.call::<()>(either)?; + assert_eq!(t.get::("hello")?, "world"); + + let f = lua.create_function(|_, either: Either| Ok(either.left().unwrap() + 1))?; + either = Either::Left(42); + assert_eq!(f.call::(either)?, 43); + + Ok(()) +} + +#[test] +fn test_either_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From value + let mut either = lua.unpack::>(Value::Integer(42))?; + assert!(either.is_left()); + assert_eq!(*either.as_ref().left().unwrap(), 42); + let t = lua.create_table()?; + either = lua.unpack::>(Value::Table(t.clone()))?; + assert!(either.is_right()); + assert_eq!(either.as_ref().right().unwrap(), &t); + match lua.unpack::>(Value::String(lua.create_string("abc")?)) { + Err(Error::FromLuaConversionError { to, .. }) => assert_eq!(to, "Either"), + _ => panic!("expected `Error::FromLuaConversionError`"), + } + + // From stack + let f = lua.create_function(|_, either: Either| Ok(either))?; + let either = f.call::>(42)?; + assert!(either.is_left()); + assert_eq!(*either.as_ref().left().unwrap(), 42); + + let either = f.call::>([5; 5])?; + assert!(either.is_right()); + assert_eq!(either.as_ref().right().unwrap(), &[5; 5]); + + // Check error message + match f.call::("hello") { + Ok(_) => panic!("expected error, got Ok"), + Err(ref err @ Error::CallbackError { ref cause, .. }) => { + match cause.as_ref() { + Error::BadArgument { cause, .. } => match cause.as_ref() { + Error::FromLuaConversionError { to, .. } => { + assert_eq!(to, "Either") + } + err => panic!("expected `Error::FromLuaConversionError`, got {err:?}"), + }, + err => panic!("expected `Error::BadArgument`, got {err:?}"), + } + assert!( + err.to_string() + .starts_with("bad argument #1: error converting Lua string to Either"), + ); + } + err => panic!("expected `Error::CallbackError`, got {err:?}"), + } + + Ok(()) +} + +#[test] +fn test_char_into_lua() -> Result<()> { + let lua = Lua::new(); + + let v = '🦀'; + let v2 = v.into_lua(&lua)?; + assert_eq!(*v2.as_string().unwrap(), v.to_string()); + + Ok(()) +} + +#[test] +fn test_char_from_lua() -> Result<()> { + let lua = Lua::new(); + + assert_eq!(lua.convert::("A")?, 'A'); + assert_eq!(lua.convert::(65)?, 'A'); + assert_eq!(lua.convert::(128175)?, '💯'); + assert!( + lua.convert::(5456324) + .is_err_and(|e| e.to_string().contains("integer out of range")) + ); + assert!( + lua.convert::("hello") + .is_err_and(|e| e.to_string().contains("expected string to have exactly one char")) + ); + assert!( + lua.convert::(HashMap::::new()) + .is_err_and(|e| e.to_string().contains("expected string or integer")) + ); + + Ok(()) +} diff --git a/tests/debug.rs b/tests/debug.rs new file mode 100644 index 00000000..24c8adcf --- /dev/null +++ b/tests/debug.rs @@ -0,0 +1,15 @@ +use mlua::{Lua, Result}; + +#[test] +fn test_debug_format() -> Result<()> { + let lua = Lua::new(); + + // Globals + let globals = lua.globals(); + let dump = format!("{globals:#?}"); + assert!(dump.starts_with("{\n _G = table:")); + + // TODO: Other cases + + Ok(()) +} diff --git a/tests/error.rs b/tests/error.rs new file mode 100644 index 00000000..6f70f770 --- /dev/null +++ b/tests/error.rs @@ -0,0 +1,114 @@ +use std::error::Error as _; +use std::{fmt, io}; + +use mlua::{Error, ErrorContext, Lua, Result}; + +#[test] +fn test_error_context() -> Result<()> { + let lua = Lua::new(); + + let func = + lua.create_function(|_, ()| Err::<(), _>(Error::runtime("runtime error")).context("some context"))?; + lua.globals().set("func", func)?; + + let msg = lua + .load("local _, err = pcall(func); return tostring(err)") + .eval::()?; + assert!(msg.contains("some context")); + assert!(msg.contains("runtime error")); + + let func2 = lua.create_function(|lua, ()| { + lua.globals() + .get::("nonextant") + .with_context(|_| "failed to find global") + })?; + lua.globals().set("func2", func2)?; + + let msg2 = lua + .load("local _, err = pcall(func2); return tostring(err)") + .eval::()?; + assert!(msg2.contains("failed to find global")); + assert!(msg2.contains("error converting Lua nil to String")); + + // Rewrite context message and test `downcast_ref` + let func3 = lua.create_function(|_, ()| { + Err::<(), _>(Error::external(io::Error::new(io::ErrorKind::Other, "other"))) + .context("some context") + .context("some new context") + })?; + let err = func3.call::<()>(()).unwrap_err(); + let err = err.parent().unwrap(); + assert!(!err.to_string().contains("some context")); + assert!(err.to_string().contains("some new context")); + assert!(err.downcast_ref::().is_some()); + assert!(err.downcast_ref::().is_none()); + + Ok(()) +} + +#[test] +fn test_error_chain() -> Result<()> { + let lua = Lua::new(); + + // Check that `Error::ExternalError` creates a chain with a single element + let io_err = io::Error::new(io::ErrorKind::Other, "other"); + assert_eq!(Error::external(io_err).chain().count(), 1); + + let func = lua.create_function(|_, ()| { + let err = Error::external(io::Error::new(io::ErrorKind::Other, "other")).context("io error"); + Err::<(), _>(err) + })?; + let err = func.call::<()>(()).unwrap_err(); + assert_eq!(err.chain().count(), 3); + for (i, err) in err.chain().enumerate() { + match i { + 0 => assert!(matches!(err.downcast_ref(), Some(Error::CallbackError { .. }))), + 1 => assert!(matches!(err.downcast_ref(), Some(Error::WithContext { .. }))), + 2 => assert!(matches!(err.downcast_ref(), Some(io::Error { .. }))), + _ => unreachable!(), + } + } + + let err = err.parent().unwrap(); + assert!(err.source().is_none()); // The source is included to the `Display` output + assert!(err.to_string().contains("io error")); + assert!(err.to_string().contains("other")); + + Ok(()) +} + +#[test] +fn test_external_error() { + // `Error::external` should preserve `mlua::Error` + let runtime_err = Error::runtime("test error"); + let converted = Error::external(runtime_err); + assert!(matches!(converted, Error::RuntimeError(ref msg) if msg == "test error")); + + // Other errors should become `ExternalError` + let converted = Error::external(io::Error::other("other error")); + assert!(matches!(converted, Error::ExternalError(_))); + assert!(converted.downcast_ref::().is_some()); +} + +#[cfg(feature = "anyhow")] +#[test] +fn test_error_anyhow() -> Result<()> { + use mlua::IntoLua; + + let lua = Lua::new(); + + let err = anyhow::Error::msg("anyhow error"); + let val = err.into_lua(&lua)?; + assert!(val.is_error()); + assert_eq!(val.as_error().unwrap().to_string(), "anyhow error"); + + // Try Error -> anyhow::Error -> Error roundtrip + let err = Error::runtime("runtime error"); + let err = anyhow::Error::new(err); + let err = err.into_lua(&lua)?; + assert!(err.is_error()); + let err = err.as_error().unwrap(); + assert!(matches!(err, Error::RuntimeError(msg) if msg == "runtime error")); + + Ok(()) +} diff --git a/tests/function.rs b/tests/function.rs index 4045d50a..8cb75d4c 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -1,27 +1,37 @@ -use mlua::{Function, Lua, Result, String}; +use std::fmt; +use std::result::Result as StdResult; + +use mlua::{Error, Function, Lua, LuaString, Result, Table, Variadic}; #[test] -fn test_function() -> Result<()> { +fn test_function_call() -> Result<()> { let lua = Lua::new(); - let globals = lua.globals(); - lua.load( - r#" - function concat(arg1, arg2) - return arg1 .. arg2 - end - "#, - ) - .exec()?; + let concat = lua + .load(r#"function(arg1, arg2) return arg1 .. arg2 end"#) + .eval::()?; + assert_eq!(concat.call::(("foo", "bar"))?, "foobar"); - let concat = globals.get::<_, Function>("concat")?; - assert_eq!(concat.call::<_, String>(("foo", "bar"))?, "foobar"); + Ok(()) +} + +#[test] +fn test_function_call_error() -> Result<()> { + let lua = Lua::new(); + + let concat_err = lua + .load(r#"function(arg1, arg2) error("concat error") end"#) + .eval::()?; + match concat_err.call::(("foo", "bar")) { + Err(Error::RuntimeError(msg)) if msg.contains("concat error") => {} + other => panic!("unexpected result: {other:?}"), + } Ok(()) } #[test] -fn test_bind() -> Result<()> { +fn test_function_bind() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); @@ -38,72 +48,97 @@ fn test_bind() -> Result<()> { ) .exec()?; - let mut concat = globals.get::<_, Function>("concat")?; + let mut concat = globals.get::("concat")?; concat = concat.bind("foo")?; concat = concat.bind("bar")?; concat = concat.bind(("baz", "baf"))?; - assert_eq!( - concat.call::<_, String>(("hi", "wut"))?, - "foobarbazbafhiwut" - ); + assert_eq!(concat.call::(())?, "foobarbazbaf"); + assert_eq!(concat.call::(("hi", "wut"))?, "foobarbazbafhiwut"); + + let mut concat2 = globals.get::("concat")?; + concat2 = concat2.bind(())?; + assert_eq!(concat2.call::(())?, ""); + assert_eq!(concat2.call::(("ab", "cd"))?, "abcd"); Ok(()) } #[test] -fn test_rust_function() -> Result<()> { +#[cfg(not(target_arch = "wasm32"))] +fn test_function_bind_error() -> Result<()> { let lua = Lua::new(); - let globals = lua.globals(); - lua.load( - r#" - function lua_function() - return rust_function() - end - - -- Test to make sure chunk return is ignored - return 1 - "#, - ) - .exec()?; - - let lua_function = globals.get::<_, Function>("lua_function")?; - let rust_function = lua.create_function(|_, ()| Ok("hello"))?; - - globals.set("rust_function", rust_function)?; - assert_eq!(lua_function.call::<_, String>(())?, "hello"); + let func = lua.load(r#"function(...) end"#).eval::()?; + assert!(func.bind(Variadic::from_iter(1..1000000)).is_err()); + assert!(func.call::<()>(Variadic::from_iter(1..1000000)).is_err()); Ok(()) } #[test] -fn test_c_function() -> Result<()> { +fn test_function_environment() -> Result<()> { let lua = Lua::new(); + let globals = lua.globals(); - unsafe extern "C" fn c_function(state: *mut mlua::lua_State) -> std::os::raw::c_int { - let lua = Lua::init_from_ptr(state); - lua.globals().set("c_function", true).unwrap(); - 0 - } + // We must not get or set environment for C functions + let rust_func = lua.create_function(|_, ()| Ok("hello"))?; + assert_eq!(rust_func.environment(), None); + assert_eq!(rust_func.set_environment(globals.clone()).ok(), Some(false)); - let func = unsafe { lua.create_c_function(c_function)? }; - func.call(())?; - assert_eq!(lua.globals().get::<_, bool>("c_function")?, true); + // Test getting Lua function environment + globals.set("hello", "global")?; + let lua_func = lua + .load( + r#" + local t = "" + return function() + -- two upvalues + return t .. hello + end + "#, + ) + .eval::()?; + let lua_func2 = lua.load("return hello").into_function()?; + assert_eq!(lua_func.call::(())?, "global"); + assert_eq!(lua_func.environment().as_ref(), Some(&globals)); - Ok(()) -} + // Test changing the environment + let env = lua.create_table_from([("hello", "local")])?; + assert!(lua_func.set_environment(env.clone())?); + assert_eq!(lua_func.call::(())?, "local"); + assert_eq!(lua_func2.call::(())?, "global"); -#[cfg(not(feature = "luau"))] -#[test] -fn test_dump() -> Result<()> { - let lua = unsafe { Lua::unsafe_new() }; + // More complex case + lua.load( + r#" + local number = 15 + function lucky() return tostring("number is "..number) end + new_env = { + tostring = function() return tostring(number) end, + } + "#, + ) + .exec()?; + let lucky = globals.get::("lucky")?; + assert_eq!(lucky.call::(())?, "number is 15"); + let new_env = globals.get::
("new_env")?; + lucky.set_environment(new_env)?; + assert_eq!(lucky.call::(())?, "15"); - let concat_lua = lua - .load(r#"function(arg1, arg2) return arg1 .. arg2 end"#) + // Test inheritance + let lua_func2 = lua + .load(r#"return function() return (function() return hello end)() end"#) .eval::()?; - let concat = lua.load(&concat_lua.dump(false)).into_function()?; + assert!(lua_func2.set_environment(env.clone())?); + lua.gc_collect()?; + assert_eq!(lua_func2.call::(())?, "local"); - assert_eq!(concat.call::<_, String>(("foo", "bar"))?, "foobar"); + // Test getting environment set by chunk loader + let chunk = lua + .load("return hello") + .set_environment(lua.create_table_from([("hello", "chunk")])?) + .into_function()?; + assert_eq!(chunk.environment().unwrap().get::("hello")?, "chunk"); Ok(()) } @@ -120,44 +155,319 @@ fn test_function_info() -> Result<()> { end "#, ) - .set_name("source1")? + .set_name("source1") .exec()?; - let function1 = globals.get::<_, Function>("function1")?; - let function2 = function1.call::<_, Function>(())?; + let function1 = globals.get::("function1")?; + let function2 = function1.call::(())?; let function3 = lua.create_function(|_, ()| Ok(()))?; let function1_info = function1.info(); #[cfg(feature = "luau")] - assert_eq!(function1_info.name, Some(b"function1".to_vec())); - assert_eq!(function1_info.source, Some(b"source1".to_vec())); - assert_eq!(function1_info.line_defined, 2); + assert_eq!(function1_info.name.as_deref(), Some("function1")); + assert_eq!(function1_info.source.as_deref(), Some("source1")); + assert_eq!(function1_info.line_defined, Some(2)); #[cfg(not(feature = "luau"))] - assert_eq!(function1_info.last_line_defined, 4); - assert_eq!(function1_info.what, Some(b"Lua".to_vec())); + assert_eq!(function1_info.last_line_defined, Some(4)); + #[cfg(feature = "luau")] + assert_eq!(function1_info.last_line_defined, None); + assert_eq!(function1_info.what, "Lua"); let function2_info = function2.info(); assert_eq!(function2_info.name, None); - assert_eq!(function2_info.source, Some(b"source1".to_vec())); - assert_eq!(function2_info.line_defined, 3); + assert_eq!(function2_info.source.as_deref(), Some("source1")); + assert_eq!(function2_info.line_defined, Some(3)); #[cfg(not(feature = "luau"))] - assert_eq!(function2_info.last_line_defined, 3); - assert_eq!(function2_info.what, Some(b"Lua".to_vec())); + assert_eq!(function2_info.last_line_defined, Some(3)); + #[cfg(feature = "luau")] + assert_eq!(function2_info.last_line_defined, None); + assert_eq!(function2_info.what, "Lua"); let function3_info = function3.info(); assert_eq!(function3_info.name, None); - assert_eq!(function3_info.source, Some(b"=[C]".to_vec())); - assert_eq!(function3_info.line_defined, -1); - #[cfg(not(feature = "luau"))] - assert_eq!(function3_info.last_line_defined, -1); - assert_eq!(function3_info.what, Some(b"C".to_vec())); + assert_eq!(function3_info.source.as_deref(), Some("=[C]")); + assert_eq!(function3_info.line_defined, None); + assert_eq!(function3_info.last_line_defined, None); + assert_eq!(function3_info.what, "C"); - let print_info = globals.get::<_, Function>("print")?.info(); + let print_info = globals.get::("print")?.info(); #[cfg(feature = "luau")] - assert_eq!(print_info.name, Some(b"print".to_vec())); - assert_eq!(print_info.source, Some(b"=[C]".to_vec())); - assert_eq!(print_info.what, Some(b"C".to_vec())); - assert_eq!(print_info.line_defined, -1); + assert_eq!(print_info.name.as_deref(), Some("print")); + assert_eq!(print_info.source.as_deref(), Some("=[C]")); + assert_eq!(print_info.what, "C"); + assert_eq!(print_info.line_defined, None); + + // Function with upvalues and params + #[cfg(not(any(feature = "lua51", feature = "luajit")))] + { + let func_with_upvalues = lua + .load( + r#" + local x, y = ... + return function(a, ...) + return a*x + y + end + "#, + ) + .call::((10, 20))?; + let func_with_upvalues_info = func_with_upvalues.info(); + assert_eq!(func_with_upvalues_info.num_upvalues, 2); + assert_eq!(func_with_upvalues_info.num_params, 1); + assert_eq!(func_with_upvalues_info.is_vararg, true); + } + + Ok(()) +} + +#[cfg(not(feature = "luau"))] +#[test] +fn test_function_dump() -> Result<()> { + let lua = unsafe { Lua::unsafe_new() }; + + let concat_lua = lua + .load(r#"function(arg1, arg2) return arg1 .. arg2 end"#) + .eval::()?; + let concat = lua.load(&concat_lua.dump(false)).into_function()?; + + assert_eq!(concat.call::(("foo", "bar"))?, "foobar"); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_function_coverage() -> Result<()> { + let lua = Lua::new(); + + lua.set_compiler(mlua::Compiler::default().set_coverage_level(1)); + + let f = lua + .load( + r#"local s = "abc" + assert(#s == 3) + + function abc(i) + if i < 5 then + return 0 + else + return 1 + end + end + + (function() + (function() abc(10) end)() + end)() + "#, + ) + .into_function()?; + + f.call::<()>(())?; + + let mut report = Vec::new(); + f.coverage(|cov| { + report.push(cov); + }); + + assert_eq!( + report[0], + mlua::function::CoverageInfo { + function: None, + line_defined: 1, + depth: 0, + hits: vec![-1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1], + } + ); + assert_eq!( + report[1], + mlua::function::CoverageInfo { + function: Some("abc".into()), + line_defined: 4, + depth: 1, + hits: vec![-1, -1, -1, -1, -1, 1, 0, -1, 1, -1, -1, -1, -1, -1, -1, -1], + } + ); + assert_eq!( + report[2], + mlua::function::CoverageInfo { + function: None, + line_defined: 12, + depth: 1, + hits: vec![-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1], + } + ); + assert_eq!( + report[3], + mlua::function::CoverageInfo { + function: None, + line_defined: 13, + depth: 2, + hits: vec![-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1], + } + ); + + Ok(()) +} + +#[test] +fn test_function_pointer() -> Result<()> { + let lua = Lua::new(); + + let func1 = lua.load("return function() end").into_function()?; + let func2 = func1.call::(())?; + + assert_eq!(func1.to_pointer(), func1.clone().to_pointer()); + assert_ne!(func1.to_pointer(), func2.to_pointer()); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_function_deep_clone() -> Result<()> { + let lua = Lua::new(); + + lua.globals().set("a", 1)?; + let func1 = lua.load("a += 1; return a").into_function()?; + let func2 = func1.deep_clone()?; + + assert_ne!(func1.to_pointer(), func2.to_pointer()); + assert_eq!(func1.call::(())?, 2); + assert_eq!(func2.call::(())?, 3); + + // Check that for Rust functions deep_clone is just a clone + let rust_func = lua.create_function(|_, ()| Ok(42))?; + let rust_func2 = rust_func.deep_clone()?; + assert_eq!(rust_func.to_pointer(), rust_func2.to_pointer()); + + Ok(()) +} + +#[test] +fn test_function_wrap() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap(|s: LuaString, n| Ok::<_, Error>(s.to_str().unwrap().repeat(n))); + lua.globals().set("f", f)?; + lua.load(r#"assert(f("hello", 2) == "hellohello")"#) + .exec() + .unwrap(); + + // Return error + let ferr = Function::wrap(|| Err::<(), _>(Error::runtime("some error"))); + lua.globals().set("ferr", ferr)?; + lua.load( + r#" + local ok, err = pcall(ferr) + assert(not ok and tostring(err):find("some error")) + "#, + ) + .exec() + .unwrap(); + + // Return external error + #[derive(Debug)] + struct MyError(String); + impl fmt::Display for MyError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MyError: {}", self.0) + } + } + impl std::error::Error for MyError {} + + let fext = Function::wrap(|s: String| -> StdResult { + if s == "bad" { + return Err(MyError("bad input".into())); + } + Ok(format!("ok: {s}")) + }); + lua.globals().set("fext", fext)?; + lua.load(r#"assert(fext("hello") == "ok: hello")"#) + .exec() + .unwrap(); + lua.load( + r#" + local ok, err = pcall(fext, "bad") + assert(not ok and tostring(err):find("MyError: bad input")) + "#, + ) + .exec() + .unwrap(); + + // Mutable callback + let mut i = 0; + let fmut = Function::wrap_mut(move || { + i += 1; + Ok::<_, Error>(i) + }); + lua.globals().set("fmut", fmut)?; + lua.load(r#"fmut(); fmut(); assert(fmut() == 3)"#).exec().unwrap(); + + // Check mutable callback with error + let fmut_err = Function::wrap_mut(|| Err::<(), _>(Error::runtime("some error"))); + lua.globals().set("fmut_err", fmut_err)?; + lua.load( + r#" + local ok, err = pcall(fmut_err) + assert(not ok and tostring(err):find("some error")) + "#, + ) + .exec() + .unwrap(); + + // Check recursive mut callback error + let fmut = Function::wrap_mut(|f: Function| match f.call::<()>(&f) { + Err(Error::CallbackError { cause, .. }) => match cause.as_ref() { + Error::RecursiveMutCallback { .. } => Ok::<_, Error>(()), + other => panic!("incorrect result: {other:?}"), + }, + other => panic!("incorrect result: {other:?}"), + }); + let fmut = lua.convert::(fmut)?; + assert!(fmut.call::<()>(&fmut).is_ok()); + + Ok(()) +} + +#[test] +fn test_function_wrap_raw() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_raw(|| "hello"); + lua.globals().set("f", f)?; + lua.load(r#"assert(f() == "hello")"#).exec().unwrap(); + + // Return error + let ferr = Function::wrap_raw(|| Err::<(), _>("some error")); + lua.globals().set("ferr", ferr)?; + lua.load( + r#" + local _, err = ferr() + assert(err == "some error") + "#, + ) + .exec() + .unwrap(); + + // Mutable callback + let mut i = 0; + let fmut = Function::wrap_raw_mut(move || { + i += 1; + i + }); + lua.globals().set("fmut", fmut)?; + lua.load(r#"fmut(); fmut(); assert(fmut() == 3)"#).exec().unwrap(); + + // Check mutable callback with error + let fmut_err = Function::wrap_raw_mut(|| Err::<(), _>("some error")); + lua.globals().set("fmut_err", fmut_err)?; + lua.load( + r#" + local _, err = fmut_err() + assert(err == "some error") + "#, + ) + .exec() + .unwrap(); Ok(()) } diff --git a/tests/hooks.rs b/tests/hooks.rs index 2fbd33d8..9d68c84b 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -1,19 +1,15 @@ #![cfg(not(feature = "luau"))] -use std::cell::RefCell; -use std::ops::Deref; -use std::str; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Mutex}; -use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value}; +use mlua::debug::DebugEvent; +use mlua::{Error, HookTriggers, Lua, Result, Value, VmState}; #[test] -fn test_hook_triggers_bitor() { - let trigger = HookTriggers::on_calls() - | HookTriggers::on_returns() - | HookTriggers::every_line() - | HookTriggers::every_nth_instruction(5); +fn test_hook_triggers() { + let trigger = HookTriggers::new().on_calls().on_returns() + | HookTriggers::new().every_line().every_nth_instruction(5); assert!(trigger.on_calls); assert!(trigger.on_returns); @@ -27,10 +23,10 @@ fn test_line_counts() -> Result<()> { let hook_output = output.clone(); let lua = Lua::new(); - lua.set_hook(HookTriggers::every_line(), move |_lua, debug| { + lua.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Line); - hook_output.lock().unwrap().push(debug.curr_line()); - Ok(()) + hook_output.lock().unwrap().push(debug.current_line().unwrap()); + Ok(VmState::Continue) })?; lua.load( r#" @@ -59,14 +55,13 @@ fn test_function_calls() -> Result<()> { let hook_output = output.clone(); let lua = Lua::new(); - lua.set_hook(HookTriggers::on_calls(), move |_lua, debug| { + lua.set_hook(HookTriggers::ON_CALLS, move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Call); let names = debug.names(); let source = debug.source(); - let name = names.name.map(|s| str::from_utf8(s).unwrap().to_owned()); - let what = source.what.map(|s| str::from_utf8(s).unwrap().to_owned()); - hook_output.lock().unwrap().push((name, what)); - Ok(()) + let name = names.name.map(|s| s.into_owned()); + hook_output.lock().unwrap().push((name, source.what)); + Ok(VmState::Continue) })?; lua.load( @@ -80,20 +75,20 @@ fn test_function_calls() -> Result<()> { let output = output.lock().unwrap(); if cfg!(feature = "luajit") && lua.load("jit.version_num").eval::()? >= 20100 { + #[cfg(not(force_memory_limit))] + assert_eq!(*output, vec![(None, "main"), (Some("len".to_string()), "Lua")]); + #[cfg(force_memory_limit)] assert_eq!( *output, - vec![ - (None, Some("main".to_string())), - (Some("len".to_string()), Some("Lua".to_string())) - ] + vec![(None, "C"), (None, "main"), (Some("len".to_string()), "Lua")] ); } else { + #[cfg(not(force_memory_limit))] + assert_eq!(*output, vec![(None, "main"), (Some("len".to_string()), "C")]); + #[cfg(force_memory_limit)] assert_eq!( *output, - vec![ - (None, Some("main".to_string())), - (Some("len".to_string()), Some("C".to_string())) - ] + vec![(None, "C"), (None, "main"), (Some("len".to_string()), "C")] ); } @@ -104,24 +99,15 @@ fn test_function_calls() -> Result<()> { fn test_error_within_hook() -> Result<()> { let lua = Lua::new(); - lua.set_hook(HookTriggers::every_line(), |_lua, _debug| { - Err(Error::RuntimeError( - "Something happened in there!".to_string(), - )) + lua.set_hook(HookTriggers::EVERY_LINE, |_lua, _debug| { + Err(Error::runtime("Something happened in there!")) })?; - let err = lua - .load("x = 1") - .exec() - .expect_err("panic didn't propagate"); - + let err = lua.load("x = 1").exec().expect_err("panic didn't propagate"); match err { - Error::CallbackError { cause, .. } => match cause.deref() { - Error::RuntimeError(s) => assert_eq!(s, "Something happened in there!"), - _ => panic!("wrong callback error kind caught"), - }, - _ => panic!("wrong error kind caught"), - }; + Error::RuntimeError(msg) => assert_eq!(msg, "Something happened in there!"), + err => panic!("expected `RuntimeError` with a specific message, got {err:?}"), + } Ok(()) } @@ -136,13 +122,13 @@ fn test_limit_execution_instructions() -> Result<()> { let max_instructions = AtomicI64::new(10000); lua.set_hook( - HookTriggers::every_nth_instruction(30), + HookTriggers::new().every_nth_instruction(30), move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Count); if max_instructions.fetch_sub(30, Ordering::Relaxed) <= 30 { - Err(Error::RuntimeError("time's up".to_string())) + Err(Error::runtime("time's up")) } else { - Ok(()) + Ok(VmState::Continue) } }, )?; @@ -166,10 +152,8 @@ fn test_limit_execution_instructions() -> Result<()> { fn test_hook_removal() -> Result<()> { let lua = Lua::new(); - lua.set_hook(HookTriggers::every_nth_instruction(1), |_lua, _debug| { - Err(Error::RuntimeError( - "this hook should've been removed by this time".to_string(), - )) + lua.set_hook(HookTriggers::new().every_nth_instruction(1), |_lua, _debug| { + Err(Error::runtime("this hook should've been removed by this time")) })?; assert!(lua.load("local x = 1").exec().is_err()); @@ -179,10 +163,13 @@ fn test_hook_removal() -> Result<()> { Ok(()) } +// Having the code compiled (even not run) on macos and luajit causes a memory reference issue +// See https://github.com/LuaJIT/LuaJIT/issues/1099 +#[cfg(not(all(feature = "luajit", target_os = "macos")))] #[test] fn test_hook_swap_within_hook() -> Result<()> { thread_local! { - static TL_LUA: RefCell> = RefCell::new(None); + static TL_LUA: std::cell::RefCell> = Default::default(); } TL_LUA.with(|tl| { @@ -193,12 +180,13 @@ fn test_hook_swap_within_hook() -> Result<()> { tl.borrow() .as_ref() .unwrap() - .set_hook(HookTriggers::every_line(), move |lua, _debug| { + .set_hook(HookTriggers::EVERY_LINE, move |lua, _debug| { lua.globals().set("ok", 1i64)?; TL_LUA.with(|tl| { - tl.borrow().as_ref().unwrap().set_hook( - HookTriggers::every_line(), - move |lua, _debug| { + tl.borrow() + .as_ref() + .unwrap() + .set_hook(HookTriggers::EVERY_LINE, move |lua, _debug| { lua.load( r#" if ok ~= nil then @@ -211,10 +199,10 @@ fn test_hook_swap_within_hook() -> Result<()> { TL_LUA.with(|tl| { tl.borrow().as_ref().unwrap().remove_hook(); }); - Ok(()) - }, - ) - }) + Ok(VmState::Continue) + }) + })?; + Ok(VmState::Continue) }) })?; @@ -229,7 +217,112 @@ fn test_hook_swap_within_hook() -> Result<()> { "#, ) .exec()?; - assert_eq!(lua.globals().get::<_, i64>("ok")?, 2); + assert_eq!(lua.globals().get::("ok")?, 2); Ok(()) }) } + +#[test] +fn test_hook_threads() -> Result<()> { + let lua = Lua::new(); + + let func = lua + .load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .into_function()?; + let co = lua.create_thread(func)?; + + let output = Arc::new(Mutex::new(Vec::new())); + let hook_output = output.clone(); + co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { + assert_eq!(debug.event(), DebugEvent::Line); + hook_output.lock().unwrap().push(debug.current_line().unwrap()); + Ok(VmState::Continue) + })?; + + co.resume::<()>(())?; + lua.remove_hook(); + + let output = output.lock().unwrap(); + if cfg!(feature = "luajit") && lua.load("jit.version_num").eval::()? >= 20100 { + assert_eq!(*output, vec![2, 3, 4, 0, 4]); + } else { + assert_eq!(*output, vec![2, 3, 4]); + } + + Ok(()) +} + +#[test] +fn test_hook_yield() -> Result<()> { + let lua = Lua::new(); + + let func = lua + .load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .into_function()?; + let co = lua.create_thread(func)?; + + co.set_hook(HookTriggers::EVERY_LINE, move |_lua, _debug| Ok(VmState::Yield))?; + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + { + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.is_finished()); + } + #[cfg(any(feature = "lua51", feature = "lua52", feature = "luajit"))] + { + assert!( + matches!(co.resume::<()>(()), Err(Error::RuntimeError(err)) if err.contains("attempt to yield from a hook")) + ); + assert!(co.is_error()); + } + + Ok(()) +} + +#[test] +fn test_global_hook() -> Result<()> { + let lua = Lua::new(); + + let counter = Arc::new(AtomicI64::new(0)); + let hook_counter = counter.clone(); + lua.set_global_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { + assert_eq!(debug.event(), DebugEvent::Line); + hook_counter.fetch_add(1, Ordering::Relaxed); + Ok(VmState::Continue) + })?; + + let thread = lua.create_thread( + lua.load( + r#" + local x = 2 + 3 + local y = x * 63 + coroutine.yield() + local z = string.len(x..", "..y) + "#, + ) + .into_function()?, + )?; + + thread.resume::<()>(()).unwrap(); + lua.remove_global_hook(); + thread.resume::<()>(()).unwrap(); + assert!(thread.is_finished()); + assert_eq!(counter.load(Ordering::Relaxed), 3); + + Ok(()) +} diff --git a/tests/luau.rs b/tests/luau.rs index 913ece20..6b5c9296 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -1,51 +1,84 @@ #![cfg(feature = "luau")] -use std::env; -use std::fs; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::cell::Cell; +use std::fmt::Debug; +use std::os::raw::c_void; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering}; -use mlua::{Compiler, Error, Lua, Result, Table, ThreadStatus, Value, VmState}; +use mlua::{ + Compiler, Error, Function, Lua, LuaOptions, ObjectLike, Result, StdLib, Table, Value, Vector, VmState, +}; #[test] -fn test_require() -> Result<()> { +fn test_version() -> Result<()> { let lua = Lua::new(); + assert!(lua.globals().get::("_VERSION")?.starts_with("Luau 0.")); + Ok(()) +} + +#[cfg(not(feature = "luau-vector4"))] +#[test] +fn test_vectors() -> Result<()> { + let lua = Lua::new(); + + let v: Vector = lua + .load("vector.create(1, 2, 3) + vector.create(3, 2, 1)") + .eval()?; + assert_eq!(v, [4.0, 4.0, 4.0]); + + // Test conversion into Rust array + let v: [f64; 3] = lua.load("vector.create(1, 2, 3)").eval()?; + assert!(v == [1.0, 2.0, 3.0]); - let temp_dir = tempfile::tempdir().unwrap(); - fs::write( - temp_dir.path().join("module.luau"), + // Test vector methods + lua.load( r#" - counter = counter or 0 - return counter + 1 + local v = vector.create(1, 2, 3) + assert(v.x == 1) + assert(v.y == 2) + assert(v.z == 3) "#, - )?; + ) + .exec()?; - env::set_var("LUAU_PATH", temp_dir.path().join("?.luau")); + // Test vector methods (fastcall) lua.load( r#" - local module = require("module") - assert(module == 1) - module = require("module") - assert(module == 1) + local v = vector.create(1, 2, 3) + assert(v.x == 1) + assert(v.y == 2) + assert(v.z == 3) "#, ) - .exec() + .set_compiler(Compiler::new().set_vector_ctor("vector")) + .exec()?; + + Ok(()) } +#[cfg(feature = "luau-vector4")] #[test] fn test_vectors() -> Result<()> { let lua = Lua::new(); - let v: [f32; 3] = lua.load("vector(1, 2, 3) + vector(3, 2, 1)").eval()?; - assert_eq!(v, [4.0, 4.0, 4.0]); + let v: Vector = lua + .load("vector.create(1, 2, 3, 4) + vector.create(4, 3, 2, 1)") + .eval()?; + assert_eq!(v, [5.0, 5.0, 5.0, 5.0]); + + // Test conversion into Rust array + let v: [f64; 4] = lua.load("vector.create(1, 2, 3, 4)").eval()?; + assert!(v == [1.0, 2.0, 3.0, 4.0]); // Test vector methods lua.load( r#" - local v = vector(1, 2, 3) + local v = vector.create(1, 2, 3, 4) assert(v.x == 1) assert(v.y == 2) assert(v.z == 3) + assert(v.w == 4) "#, ) .exec()?; @@ -53,13 +86,56 @@ fn test_vectors() -> Result<()> { // Test vector methods (fastcall) lua.load( r#" - local v = vector(1, 2, 3) + local v = vector.create(1, 2, 3, 4) assert(v.x == 1) assert(v.y == 2) assert(v.z == 3) + assert(v.w == 4) "#, ) - .set_compiler(Compiler::new().set_vector_ctor(Some("vector".to_string()))) + .set_compiler(Compiler::new().set_vector_ctor("vector")) + .exec()?; + + Ok(()) +} + +#[cfg(not(feature = "luau-vector4"))] +#[test] +fn test_vector_metatable() -> Result<()> { + let lua = Lua::new(); + + let vector_mt = lua + .load( + r#" + { + __index = { + new = vector.create, + + product = function(a, b) + return vector.create(a.x * b.x, a.y * b.y, a.z * b.z) + end + } + } + "#, + ) + .eval::
()?; + vector_mt.set_metatable(Some(vector_mt.clone()))?; + lua.set_type_metatable::(Some(vector_mt.clone())); + lua.globals().set("Vector3", vector_mt)?; + + let compiler = Compiler::new() + .set_vector_ctor("Vector3.new") + .set_vector_type("Vector3"); + + // Test vector methods (fastcall) + lua.load( + r#" + local v = Vector3.new(1, 2, 3) + local v2 = v:product(Vector3.new(2, 3, 4)) + assert(v2.x == 2 and v2.y == 6 and v2.z == 12) + "#, + ) + .set_compiler(compiler) .exec()?; Ok(()) @@ -69,18 +145,33 @@ fn test_vectors() -> Result<()> { fn test_readonly_table() -> Result<()> { let lua = Lua::new(); - let t = lua.create_table()?; + let t = lua.create_sequence_from([1])?; assert!(!t.is_readonly()); t.set_readonly(true); assert!(t.is_readonly()); - match t.set("key", "value") { - Err(Error::RuntimeError(err)) if err.contains("Attempt to modify a readonly table") => {} - r => panic!( - "expected RuntimeError(...) with a specific message, got {:?}", - r - ), - }; + #[track_caller] + fn check_readonly_error(res: Result) { + match res { + Err(Error::RuntimeError(e)) if e.contains("attempt to modify a readonly table") => {} + r => panic!("expected RuntimeError(...) with a specific message, got {r:?}"), + } + } + + check_readonly_error(t.set("key", "value")); + check_readonly_error(t.raw_set("key", "value")); + check_readonly_error(t.raw_insert(1, "value")); + check_readonly_error(t.raw_remove(1)); + check_readonly_error(t.push("value")); + check_readonly_error(t.pop::()); + check_readonly_error(t.raw_push("value")); + check_readonly_error(t.raw_pop::()); + + // Special case + match t.set_metatable(None) { + Err(Error::RuntimeError(e)) if e.contains("attempt to modify a readonly table") => {} + r => panic!("expected RuntimeError(...) with a specific message, got {r:?}"), + } Ok(()) } @@ -94,27 +185,70 @@ fn test_sandbox() -> Result<()> { lua.load("global = 123").exec()?; let n: i32 = lua.load("return global").eval()?; assert_eq!(n, 123); - assert_eq!(lua.globals().get::<_, Option>("global")?, Some(123)); + assert_eq!(lua.globals().get::>("global")?, Some(123)); // Threads should inherit "main" globals - let f = lua.create_function(|lua, ()| lua.globals().get::<_, i32>("global"))?; + let f = lua.create_function(|lua, ()| lua.globals().get::("global"))?; let co = lua.create_thread(f.clone())?; - assert_eq!(co.resume::<_, Option>(())?, Some(123)); + assert_eq!(co.resume::>(())?, Some(123)); // Sandboxed threads should also inherit "main" globals let co = lua.create_thread(f)?; co.sandbox()?; - assert_eq!(co.resume::<_, Option>(())?, Some(123)); + assert_eq!(co.resume::>(())?, Some(123)); + + // collectgarbage should be restricted in sandboxed mode + let collectgarbage = lua.globals().get::("collectgarbage")?; + for arg in ["collect", "stop", "restart", "step", "isrunning"] { + let err = collectgarbage.call::<()>(arg).err().unwrap().to_string(); + assert!(err.contains("collectgarbage called with invalid option")); + } + assert!(collectgarbage.call::("count").unwrap() > 0); lua.sandbox(false)?; // Previously set variable `global` should be cleared now - assert_eq!(lua.globals().get::<_, Option>("global")?, None); + assert_eq!(lua.globals().get::>("global")?, None); // Readonly flags should be cleared as well - let table = lua.globals().get::<_, Table>("table")?; + let table = lua.globals().get::
("table")?; table.set("test", "test")?; + // collectgarbage should work now + for arg in ["collect", "stop", "restart", "count", "step", "isrunning"] { + collectgarbage.call::<()>(arg).unwrap(); + } + + Ok(()) +} + +#[test] +fn test_sandbox_safeenv() -> Result<()> { + let lua = Lua::new(); + + lua.sandbox(true)?; + lua.globals().set("state", lua.create_table()?)?; + lua.globals().set_safeenv(false); + lua.load("state.a = 123").exec()?; + let a: i32 = lua.load("state.a = 321; return state.a").eval()?; + assert_eq!(a, 321); + + Ok(()) +} + +#[test] +fn test_sandbox_nolibs() -> Result<()> { + let lua = Lua::new_with(StdLib::NONE, LuaOptions::default()).unwrap(); + + lua.sandbox(true)?; + lua.load("global = 123").exec()?; + let n: i32 = lua.load("return global").eval()?; + assert_eq!(n, 123); + assert_eq!(lua.globals().get::>("global")?, Some(123)); + + lua.sandbox(false)?; + assert_eq!(lua.globals().get::>("global")?, None); + Ok(()) } @@ -125,20 +259,20 @@ fn test_sandbox_threads() -> Result<()> { let f = lua.create_function(|lua, v: Value| lua.globals().set("global", v))?; let co = lua.create_thread(f.clone())?; - co.resume(321)?; + co.resume::<()>(321)?; // The main state should see the `global` variable (as the thread is not sandboxed) - assert_eq!(lua.globals().get::<_, Option>("global")?, Some(321)); + assert_eq!(lua.globals().get::>("global")?, Some(321)); let co = lua.create_thread(f.clone())?; co.sandbox()?; - co.resume(123)?; + co.resume::<()>(123)?; // The main state should see the previous `global` value (as the thread is sandboxed) - assert_eq!(lua.globals().get::<_, Option>("global")?, Some(321)); + assert_eq!(lua.globals().get::>("global")?, Some(321)); // Try to reset the (sandboxed) thread co.reset(f)?; - co.resume(111)?; - assert_eq!(lua.globals().get::<_, Option>("global")?, Some(111)); + co.resume::<()>(111)?; + assert_eq!(lua.globals().get::>("global")?, Some(111)); Ok(()) } @@ -150,7 +284,7 @@ fn test_interrupts() -> Result<()> { let interrupts_count = Arc::new(AtomicU64::new(0)); let interrupts_count2 = interrupts_count.clone(); - lua.set_interrupt(move || { + lua.set_interrupt(move |_| { interrupts_count2.fetch_add(1, Ordering::Relaxed); Ok(VmState::Continue) }); @@ -163,7 +297,7 @@ fn test_interrupts() -> Result<()> { "#, ) .into_function()?; - f.call(())?; + f.call::<()>(())?; assert!(interrupts_count.load(Ordering::Relaxed) > 0); @@ -172,7 +306,7 @@ fn test_interrupts() -> Result<()> { // let yield_count = Arc::new(AtomicU64::new(0)); let yield_count2 = yield_count.clone(); - lua.set_interrupt(move || { + lua.set_interrupt(move |_| { if yield_count2.fetch_add(1, Ordering::Relaxed) == 1 { return Ok(VmState::Yield); } @@ -189,26 +323,235 @@ fn test_interrupts() -> Result<()> { ) .into_function()?, )?; - co.resume(())?; - assert_eq!(co.status(), ThreadStatus::Resumable); + co.resume::<()>(())?; + assert!(co.is_resumable()); let result: i32 = co.resume(())?; assert_eq!(result, 6); assert_eq!(yield_count.load(Ordering::Relaxed), 7); - assert_eq!(co.status(), ThreadStatus::Unresumable); + assert!(co.is_finished()); + + // Test no yielding at non-yieldable points + yield_count.store(0, Ordering::Relaxed); + let co = lua.create_thread(lua.create_function(|lua, arg: Value| { + (lua.load("return (function(x) return x end)(...)")).call::(arg) + })?)?; + let res = co.resume::("abc")?; + assert_eq!(res, "abc".to_string()); + assert_eq!(yield_count.load(Ordering::Relaxed), 3); // // Test errors in interrupts // - lua.set_interrupt(|| Err(Error::RuntimeError("error from interrupt".into()))); - match f.call::<_, ()>(()) { - Err(Error::CallbackError { cause, .. }) => match *cause { - Error::RuntimeError(ref m) if m == "error from interrupt" => {} - ref e => panic!("expected RuntimeError with a specific message, got {:?}", e), - }, - r => panic!("expected CallbackError, got {:?}", r), + lua.set_interrupt(|_| Err(Error::runtime("error from interrupt"))); + match f.call::<()>(()) { + Err(Error::RuntimeError(ref msg)) => assert_eq!(msg, "error from interrupt"), + res => panic!("expected `RuntimeError` with a specific message, got {res:?}"), } lua.remove_interrupt(); Ok(()) } + +#[test] +fn test_fflags() { + // We cannot really on any particular feature flag to be present + assert!(Lua::set_fflag("UnknownFlag", true).is_err()); +} + +#[test] +fn test_thread_events() -> Result<()> { + let lua = Lua::new(); + + let count = Arc::new(AtomicU64::new(0)); + let thread_data: Arc<(AtomicPtr, AtomicBool)> = Arc::new(Default::default()); + + let (count2, thread_data2) = (count.clone(), thread_data.clone()); + lua.set_thread_creation_callback(move |_, thread| { + count2.fetch_add(1, Ordering::Relaxed); + (thread_data2.0).store(thread.to_pointer() as *mut _, Ordering::Relaxed); + thread_data2.1.store(false, Ordering::Relaxed); + Ok(()) + }); + let (count3, thread_data3) = (count.clone(), thread_data.clone()); + lua.set_thread_collection_callback(move |thread_ptr| { + count3.fetch_add(1, Ordering::Relaxed); + if thread_data3.0.load(Ordering::Relaxed) == thread_ptr.0 { + thread_data3.1.store(true, Ordering::Relaxed); + } + }); + + let t = lua.create_thread(lua.load("return 123").into_function()?)?; + assert_eq!(count.load(Ordering::Relaxed), 1); + let t_ptr = t.to_pointer(); + assert_eq!(t_ptr, thread_data.0.load(Ordering::Relaxed)); + assert!(!thread_data.1.load(Ordering::Relaxed)); + + // Thead will be destroyed after GC cycle + drop(t); + lua.gc_collect()?; + assert_eq!(count.load(Ordering::Relaxed), 2); + assert_eq!(t_ptr, thread_data.0.load(Ordering::Relaxed)); + assert!(thread_data.1.load(Ordering::Relaxed)); + + // Check that recursion is not allowed + let count4 = count.clone(); + lua.set_thread_creation_callback(move |lua, _value| { + count4.fetch_add(1, Ordering::Relaxed); + let _ = lua.create_thread(lua.load("return 123").into_function().unwrap())?; + Ok(()) + }); + let t = lua.create_thread(lua.load("return 123").into_function()?)?; + assert_eq!(count.load(Ordering::Relaxed), 3); + + lua.remove_thread_callbacks(); + drop(t); + lua.gc_collect()?; + assert_eq!(count.load(Ordering::Relaxed), 3); + + // Test error inside callback + lua.set_thread_creation_callback(move |_, _| Err(Error::runtime("error when processing thread event"))); + let result = lua.create_thread(lua.load("return 123").into_function()?); + assert!(result.is_err()); + assert!( + matches!(result, Err(Error::RuntimeError(err)) if err.contains("error when processing thread event")) + ); + + // Test context switch when running Lua script + let count = Cell::new(0); + lua.set_thread_creation_callback(move |_, _| { + count.set(count.get() + 1); + if count.get() == 2 { + return Err(Error::runtime("thread limit exceeded")); + } + Ok(()) + }); + let result = lua + .load( + r#" + local co = coroutine.wrap(function() return coroutine.create(print) end) + co() + "#, + ) + .exec(); + assert!(result.is_err()); + assert!(matches!(result, Err(Error::RuntimeError(err)) if err.contains("thread limit exceeded"))); + + Ok(()) +} + +#[test] +fn test_loadstring() -> Result<()> { + let lua = Lua::new(); + + let f = lua.load(r#"loadstring("return 123")"#).eval::()?; + assert_eq!(f.call::(())?, 123); + + let err = lua + .load(r#"loadstring("retur 123", "chunk")"#) // typos:ignore + .exec() + .err() + .unwrap(); + assert!(err.to_string().contains( + r#"syntax error: [string "chunk"]:1: Incomplete statement: expected assignment or a function call"# + )); + + Ok(()) +} + +#[test] +fn test_typeof_error() -> Result<()> { + let lua = Lua::new(); + + let err = Error::runtime("just a test error"); + let res = lua.load("return typeof(...)").call::(err)?; + assert_eq!(res, "error"); + + Ok(()) +} + +#[test] +fn test_memory_category() -> Result<()> { + let lua = Lua::new(); + + lua.set_memory_category("main").unwrap(); + + // Invalid category names should be rejected + let err = lua.set_memory_category("invalid$"); + assert!(err.is_err()); + + for i in 0..254 { + let name = format!("category_{}", i); + lua.set_memory_category(&name).unwrap(); + } + // 255th category should fail + let err = lua.set_memory_category("category_254"); + assert!(err.is_err()); + + Ok(()) +} + +#[test] +fn test_heap_dump() -> Result<()> { + let lua = Lua::new(); + + // Assign a new memory category and create few objects + lua.set_memory_category("test_category")?; + let _t = lua.create_table()?; + let _ud = lua.create_any_userdata("hello, world")?; + + let dump = lua.heap_dump()?; + + assert!(dump.size() > 0); + let size_by_category = dump.size_by_category(); + assert_eq!(size_by_category.len(), 2); + assert!(size_by_category.contains_key("test_category")); + assert!(size_by_category["main"] < dump.size()); + + // Check size by type within the category + let size_by_type = dump.size_by_type(Some("test_category")); + assert!(!size_by_type.is_empty()); + assert!(size_by_type.contains_key("table")); + assert!(size_by_type.contains_key("userdata")); + // Try non-existent category + let size_by_type2 = dump.size_by_type(Some("non_existent_category")); + assert!(size_by_type2.is_empty()); + // Remove category filter + let size_by_type_all = dump.size_by_type(None); + assert!(size_by_type.len() < size_by_type_all.len()); + + // Check size by userdata type within the category + let size_by_udtype = dump.size_by_userdata(Some("test_category")); + assert_eq!(size_by_udtype.len(), 1); + assert!(size_by_udtype.contains_key("&str")); + assert_eq!(size_by_udtype["&str"].0, 1); + // Try non-existent category + let size_by_udtype2 = dump.size_by_userdata(Some("non_existent_category")); + assert!(size_by_udtype2.is_empty()); + // Remove category filter + let size_by_udtype_all = dump.size_by_userdata(None); + assert!(size_by_udtype.len() < size_by_udtype_all.len()); + + Ok(()) +} + +#[test] +fn test_integer64_type() -> Result<()> { + let lua = Lua::new(); + + _ = Lua::set_fflag("LuauIntegerType", true); + + let integer_lib = lua.globals().get::
("integer")?; + let n = integer_lib.call_function::("create", 42)?; + assert_eq!(n, 42); + + let n: i64 = lua.load("return 42i").eval()?; + assert_eq!(n, 42); + let n: i64 = lua.load("return -42i").eval()?; + assert_eq!(n, -42); + + Ok(()) +} + +#[path = "luau/require.rs"] +mod require; diff --git a/tests/luau/require.rs b/tests/luau/require.rs new file mode 100644 index 00000000..eace354c --- /dev/null +++ b/tests/luau/require.rs @@ -0,0 +1,292 @@ +use std::io::Result as IoResult; +use std::result::Result as StdResult; + +use mlua::luau::{FsRequirer, NavigateError, Require}; +use mlua::{Error, FromLua, IntoLua, Lua, MultiValue, Result, Value}; + +fn run_require(lua: &Lua, path: impl IntoLua) -> Result { + lua.load(r#"return require(...)"#).call(path) +} + +fn run_require_pcall(lua: &Lua, path: impl IntoLua) -> Result { + lua.load(r#"return pcall(require, ...)"#).call(path) +} + +#[track_caller] +fn get_value(value: &Value, key: impl IntoLua) -> V { + value.as_table().unwrap().get(key).unwrap() +} + +#[track_caller] +fn get_str(value: &Value, key: impl IntoLua) -> String { + get_value(value, key) +} + +#[test] +fn test_require_errors() { + let lua = Lua::new(); + + // RequireAbsolutePath + let res = run_require(&lua, "/an/absolute/path"); + assert!(res.is_err()); + assert!( + (res.unwrap_err().to_string()).contains("require path must start with a valid prefix: ./, ../, or @") + ); + + // RequireUnprefixedPath + let res = run_require(&lua, "an/unprefixed/path"); + assert!(res.is_err()); + assert!( + (res.unwrap_err().to_string()).contains("require path must start with a valid prefix: ./, ../, or @") + ); + + // Pass non-string to require + let res = run_require(&lua, true); + assert!(res.is_err()); + assert!( + (res.unwrap_err().to_string()) + .contains("bad argument #1 to 'require' (string expected, got boolean)") + ); + + // Require from loadstring + let res = lua + .load(r#"return loadstring("require('./a/relative/path')")()"#) + .eval::(); + assert!(res.is_err()); + assert!((res.unwrap_err().to_string()).contains("require is not supported in this context")); + + // RequireAliasThatDoesNotExist + let res = run_require(&lua, "@this.alias.does.not.exist"); + assert!(res.is_err()); + assert!((res.unwrap_err().to_string()).contains("@this.alias.does.not.exist is not a valid alias")); + + // IllegalAlias + let res = run_require(&lua, "@"); + assert!(res.is_err()); + assert!((res.unwrap_err().to_string()).contains("@ is not a valid alias")); + + // Test throwing mlua::Error + struct MyRequire(FsRequirer); + + impl Require for MyRequire { + fn is_require_allowed(&self, chunk_name: &str) -> bool { + self.0.is_require_allowed(chunk_name) + } + + fn reset(&mut self, _chunk_name: &str) -> StdResult<(), NavigateError> { + Err(Error::runtime("test error"))? + } + + fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> { + self.0.jump_to_alias(path) + } + + fn to_parent(&mut self) -> StdResult<(), NavigateError> { + self.0.to_parent() + } + + fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> { + self.0.to_child(name) + } + + fn has_module(&self) -> bool { + self.0.has_module() + } + + fn cache_key(&self) -> String { + self.0.cache_key() + } + + fn has_config(&self) -> bool { + self.0.has_config() + } + + fn config(&self) -> IoResult> { + self.0.config() + } + + fn loader(&self, lua: &Lua) -> Result { + self.0.loader(lua) + } + } + + let require = lua.create_require_function(MyRequire(FsRequirer::new())).unwrap(); + lua.globals().set("require", require).unwrap(); + let res = lua.load(r#"return require('./a/relative/path')"#).exec(); + assert!((res.unwrap_err().to_string()).contains("test error")); +} + +#[test] +fn test_require_without_config() { + let lua = Lua::new(); + + // RequireSimpleRelativePath + let res = run_require(&lua, "./tests/luau/require/without_config/dependency").unwrap(); + assert_eq!("result from dependency", get_str(&res, 1)); + + // RequireSimpleRelativePathWithinPcall + let res = run_require_pcall(&lua, "./tests/luau/require/without_config/dependency").unwrap(); + assert!(res[0].as_boolean().unwrap()); + assert_eq!("result from dependency", get_str(&res[1], 1)); + + // RequireRelativeToRequiringFile + let res = run_require(&lua, "./tests/luau/require/without_config/module").unwrap(); + assert_eq!("result from dependency", get_str(&res, 1)); + assert_eq!("required into module", get_str(&res, 2)); + + // RequireLua + let res = run_require(&lua, "./tests/luau/require/without_config/lua_dependency").unwrap(); + assert_eq!("result from lua_dependency", get_str(&res, 1)); + + // RequireInitLuau + let res = run_require(&lua, "./tests/luau/require/without_config/luau").unwrap(); + assert_eq!("result from init.luau", get_str(&res, 1)); + + // RequireInitLua + let res = run_require(&lua, "./tests/luau/require/without_config/lua").unwrap(); + assert_eq!("result from init.lua", get_str(&res, 1)); + + // RequireSubmoduleUsingSelfIndirectly + let res = run_require(&lua, "./tests/luau/require/without_config/nested_module_requirer").unwrap(); + assert_eq!("result from submodule", get_str(&res, 1)); + + // RequireSubmoduleUsingSelfDirectly + let res = run_require(&lua, "./tests/luau/require/without_config/nested").unwrap(); + assert_eq!("result from submodule", get_str(&res, 1)); + + // CannotRequireInitLuauDirectly + let res = run_require(&lua, "./tests/luau/require/without_config/nested/init"); + assert!(res.is_err()); + assert!((res.unwrap_err().to_string()).contains("could not resolve child component \"init\"")); + + // RequireNestedInits + let res = run_require(&lua, "./tests/luau/require/without_config/nested_inits_requirer").unwrap(); + assert_eq!("result from nested_inits/init", get_str(&res, 1)); + assert_eq!("required into module", get_str(&res, 2)); + + // RequireWithFileAmbiguity + let res = run_require( + &lua, + "./tests/luau/require/without_config/ambiguous_file_requirer", + ); + assert!(res.is_err()); + assert!( + (res.unwrap_err().to_string()) + .contains("could not resolve child component \"dependency\" (ambiguous)") + ); + + // RequireWithDirectoryAmbiguity + let res = run_require( + &lua, + "./tests/luau/require/without_config/ambiguous_directory_requirer", + ); + assert!(res.is_err()); + assert!( + (res.unwrap_err().to_string()) + .contains("could not resolve child component \"dependency\" (ambiguous)") + ); + + // CheckCachedResult + let res = run_require(&lua, "./tests/luau/require/without_config/validate_cache").unwrap(); + assert!(res.is_table()); +} + +fn test_require_with_config_inner(r#type: &str) { + let lua = Lua::new(); + + let base_path = format!("./tests/luau/require/{type}"); + + // RequirePathWithAlias + let res = run_require(&lua, format!("{base_path}/src/alias_requirer")).unwrap(); + assert_eq!("result from dependency", get_str(&res, 1)); + + // RequirePathWithAlias (case-insensitive) + let res2 = run_require(&lua, format!("{base_path}/src/alias_requirer_uc")).unwrap(); + assert_eq!("result from dependency", get_str(&res2, 1)); + assert_eq!(res.to_pointer(), res2.to_pointer()); + + // RequirePathWithParentAlias + let res = run_require(&lua, format!("{base_path}/src/parent_alias_requirer")).unwrap(); + assert_eq!("result from other_dependency", get_str(&res, 1)); + + // RequirePathWithAliasPointingToDirectory + let res = run_require(&lua, format!("{base_path}/src/directory_alias_requirer")).unwrap(); + assert_eq!("result from subdirectory_dependency", get_str(&res, 1)); + + // RequireChainedAliasesSuccess + let res = run_require( + &lua, + format!("{base_path}/chained_aliases/subdirectory/successful_requirer"), + ) + .unwrap(); + assert_eq!("result from inner_dependency", get_str(&get_value(&res, 1), 1)); + assert_eq!("result from outer_dependency", get_str(&get_value(&res, 2), 1)); + + // RequireChainedAliasesFailureCyclic + let res = run_require( + &lua, + format!("{base_path}/chained_aliases/subdirectory/failing_requirer_cyclic"), + ); + assert!(res.is_err()); + let err_msg = "error requiring module \"@cyclicentry\": detected alias cycle (@cyclic1 -> @cyclic2 -> @cyclic3 -> @cyclic1)"; + assert!(res.unwrap_err().to_string().contains(err_msg)); + + // RequireChainedAliasesFailureMissing + let res = run_require( + &lua, + format!("{base_path}/chained_aliases/subdirectory/failing_requirer_missing"), + ); + assert!(res.is_err()); + let err_msg = "error requiring module \"@brokenchain\": @missing is not a valid alias"; + assert!(res.unwrap_err().to_string().contains(err_msg)); +} + +#[test] +fn test_require_with_config() { + test_require_with_config_inner("with_config"); +} + +#[test] +fn test_require_with_config_luau() { + test_require_with_config_inner("with_config_luau"); +} + +#[cfg(all(feature = "async", not(windows)))] +#[tokio::test] +async fn test_async_require() -> Result<()> { + let lua = Lua::new(); + + let temp_dir = tempfile::tempdir().unwrap(); + let temp_path = temp_dir.path().join("async_chunk.luau"); + std::fs::write( + &temp_path, + r#" + sleep_ms(10) + return "result_after_async_sleep" + "#, + ) + .unwrap(); + + lua.globals().set( + "sleep_ms", + lua.create_async_function(|_, ms: u64| async move { + tokio::time::sleep(std::time::Duration::from_millis(ms)).await; + Ok(()) + })?, + )?; + lua.globals().set("tmp_dir", temp_dir.path().to_str().unwrap())?; + lua.globals().set( + "curr_dir_components", + std::env::current_dir().unwrap().components().count(), + )?; + + lua.load( + r#" + local path_to_root = string.rep("/..", curr_dir_components - 1) + local result = require(`.{path_to_root}{tmp_dir}/async_chunk`) + assert(result == "result_after_async_sleep") + "#, + ) + .exec_async() + .await +} diff --git a/tests/luau/require/with_config/.luaurc b/tests/luau/require/with_config/.luaurc new file mode 100644 index 00000000..2b64ad06 --- /dev/null +++ b/tests/luau/require/with_config/.luaurc @@ -0,0 +1,6 @@ +{ + "aliases": { + "dep": "./this_should_be_overwritten_by_child_luaurc", + "otherdep": "./src/other_dependency" + } +} diff --git a/tests/luau/require/with_config/chained_aliases/.luaurc b/tests/luau/require/with_config/chained_aliases/.luaurc new file mode 100644 index 00000000..42e61fcd --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/.luaurc @@ -0,0 +1,9 @@ +{ + "aliases":{ + "outer": "./", + "cyclicentry": "@cyclic1", + "cyclic1": "@cyclic2", + "cyclic2": "@cyclic3", + "cyclic3": "@cyclic1" + } +} diff --git a/tests/luau/require/with_config/chained_aliases/outer_dependency.luau b/tests/luau/require/with_config/chained_aliases/outer_dependency.luau new file mode 100644 index 00000000..69ffda57 --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/outer_dependency.luau @@ -0,0 +1 @@ +return {"result from outer_dependency"} diff --git a/tests/luau/require/with_config/chained_aliases/subdirectory/.luaurc b/tests/luau/require/with_config/chained_aliases/subdirectory/.luaurc new file mode 100644 index 00000000..96b72086 --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/subdirectory/.luaurc @@ -0,0 +1,10 @@ +{ + "aliases":{ + "passthroughinner": "./inner_dependency", + "passthroughouter": "@outer", + "dep": "@passthroughinner", + "outerdep": "@outer/outer_dependency", + "outerdir": "@passthroughouter", + "brokenchain": "@missing" + } +} diff --git a/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_cyclic.luau b/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_cyclic.luau new file mode 100644 index 00000000..9f5ed488 --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_cyclic.luau @@ -0,0 +1 @@ +return require("@cyclicentry") diff --git a/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_missing.luau b/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_missing.luau new file mode 100644 index 00000000..75703849 --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/subdirectory/failing_requirer_missing.luau @@ -0,0 +1 @@ +return require("@brokenchain") diff --git a/tests/luau/require/with_config/chained_aliases/subdirectory/inner_dependency.luau b/tests/luau/require/with_config/chained_aliases/subdirectory/inner_dependency.luau new file mode 100644 index 00000000..917d461f --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/subdirectory/inner_dependency.luau @@ -0,0 +1 @@ +return {"result from inner_dependency"} diff --git a/tests/luau/require/with_config/chained_aliases/subdirectory/successful_requirer.luau b/tests/luau/require/with_config/chained_aliases/subdirectory/successful_requirer.luau new file mode 100644 index 00000000..988e467b --- /dev/null +++ b/tests/luau/require/with_config/chained_aliases/subdirectory/successful_requirer.luau @@ -0,0 +1,7 @@ +local result = {} + +table.insert(result, require("@dep")) +table.insert(result, require("@outerdep")) +table.insert(result, require("@outerdir/outer_dependency")) + +return result diff --git a/tests/luau/require/with_config/src/.luaurc b/tests/luau/require/with_config/src/.luaurc new file mode 100644 index 00000000..27263339 --- /dev/null +++ b/tests/luau/require/with_config/src/.luaurc @@ -0,0 +1,6 @@ +{ + "aliases": { + "dep": "./dependency", + "subdir": "./subdirectory" + } +} diff --git a/tests/luau/require/with_config/src/alias_requirer.luau b/tests/luau/require/with_config/src/alias_requirer.luau new file mode 100644 index 00000000..4375a783 --- /dev/null +++ b/tests/luau/require/with_config/src/alias_requirer.luau @@ -0,0 +1 @@ +return require("@dep") diff --git a/tests/luau/require/with_config/src/alias_requirer_uc.luau b/tests/luau/require/with_config/src/alias_requirer_uc.luau new file mode 100644 index 00000000..7fa5dc9e --- /dev/null +++ b/tests/luau/require/with_config/src/alias_requirer_uc.luau @@ -0,0 +1 @@ +return require("@DeP") diff --git a/tests/luau/require/with_config/src/dependency.luau b/tests/luau/require/with_config/src/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/with_config/src/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/with_config/src/directory_alias_requirer.luau b/tests/luau/require/with_config/src/directory_alias_requirer.luau new file mode 100644 index 00000000..3b19d4ff --- /dev/null +++ b/tests/luau/require/with_config/src/directory_alias_requirer.luau @@ -0,0 +1 @@ +return(require("@subdir/subdirectory_dependency")) diff --git a/tests/luau/require/with_config/src/other_dependency.luau b/tests/luau/require/with_config/src/other_dependency.luau new file mode 100644 index 00000000..8c582dc2 --- /dev/null +++ b/tests/luau/require/with_config/src/other_dependency.luau @@ -0,0 +1 @@ +return {"result from other_dependency"} diff --git a/tests/luau/require/with_config/src/parent_alias_requirer.luau b/tests/luau/require/with_config/src/parent_alias_requirer.luau new file mode 100644 index 00000000..a8e8de09 --- /dev/null +++ b/tests/luau/require/with_config/src/parent_alias_requirer.luau @@ -0,0 +1 @@ +return require("@otherdep") diff --git a/tests/luau/require/with_config/src/subdirectory/subdirectory_dependency.luau b/tests/luau/require/with_config/src/subdirectory/subdirectory_dependency.luau new file mode 100644 index 00000000..8bbd0beb --- /dev/null +++ b/tests/luau/require/with_config/src/subdirectory/subdirectory_dependency.luau @@ -0,0 +1 @@ +return {"result from subdirectory_dependency"} diff --git a/tests/luau/require/with_config_luau/.config.luau b/tests/luau/require/with_config_luau/.config.luau new file mode 100644 index 00000000..b64979de --- /dev/null +++ b/tests/luau/require/with_config_luau/.config.luau @@ -0,0 +1,8 @@ +return { + luau = { + aliases = { + dep = "./this_should_be_overwritten_by_child_luaurc", + otherdep = "./src/other_dependency" + } + } +} diff --git a/tests/luau/require/with_config_luau/chained_aliases/.config.luau b/tests/luau/require/with_config_luau/chained_aliases/.config.luau new file mode 100644 index 00000000..fa04eb6f --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/.config.luau @@ -0,0 +1,11 @@ +return { + luau = { + aliases = { + outer = "./", + cyclicentry = "@cyclic1", + cyclic1 = "@cyclic2", + cyclic2 = "@cyclic3", + cyclic3 = "@cyclic1" + } + } +} diff --git a/tests/luau/require/with_config_luau/chained_aliases/outer_dependency.luau b/tests/luau/require/with_config_luau/chained_aliases/outer_dependency.luau new file mode 100644 index 00000000..69ffda57 --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/outer_dependency.luau @@ -0,0 +1 @@ +return {"result from outer_dependency"} diff --git a/tests/luau/require/with_config_luau/chained_aliases/subdirectory/.config.luau b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/.config.luau new file mode 100644 index 00000000..b7faaa47 --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/.config.luau @@ -0,0 +1,12 @@ +return { + luau = { + aliases = { + passthroughinner = "./inner_dependency", + passthroughouter = "@outer", + dep = "@passthroughinner", + outerdep = "@outer/outer_dependency", + outerdir = "@passthroughouter", + brokenchain = "@missing" + } + } +} diff --git a/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_cyclic.luau b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_cyclic.luau new file mode 100644 index 00000000..9f5ed488 --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_cyclic.luau @@ -0,0 +1 @@ +return require("@cyclicentry") diff --git a/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_missing.luau b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_missing.luau new file mode 100644 index 00000000..75703849 --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/failing_requirer_missing.luau @@ -0,0 +1 @@ +return require("@brokenchain") diff --git a/tests/luau/require/with_config_luau/chained_aliases/subdirectory/inner_dependency.luau b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/inner_dependency.luau new file mode 100644 index 00000000..917d461f --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/inner_dependency.luau @@ -0,0 +1 @@ +return {"result from inner_dependency"} diff --git a/tests/luau/require/with_config_luau/chained_aliases/subdirectory/successful_requirer.luau b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/successful_requirer.luau new file mode 100644 index 00000000..988e467b --- /dev/null +++ b/tests/luau/require/with_config_luau/chained_aliases/subdirectory/successful_requirer.luau @@ -0,0 +1,7 @@ +local result = {} + +table.insert(result, require("@dep")) +table.insert(result, require("@outerdep")) +table.insert(result, require("@outerdir/outer_dependency")) + +return result diff --git a/tests/luau/require/with_config_luau/src/.config.luau b/tests/luau/require/with_config_luau/src/.config.luau new file mode 100644 index 00000000..63d3bbc8 --- /dev/null +++ b/tests/luau/require/with_config_luau/src/.config.luau @@ -0,0 +1,8 @@ +return { + luau = { + aliases = { + dep = "./dependency", + subdir = "./subdirectory" + } + } +} diff --git a/tests/luau/require/with_config_luau/src/alias_requirer.luau b/tests/luau/require/with_config_luau/src/alias_requirer.luau new file mode 100644 index 00000000..4375a783 --- /dev/null +++ b/tests/luau/require/with_config_luau/src/alias_requirer.luau @@ -0,0 +1 @@ +return require("@dep") diff --git a/tests/luau/require/with_config_luau/src/alias_requirer_uc.luau b/tests/luau/require/with_config_luau/src/alias_requirer_uc.luau new file mode 100644 index 00000000..7fa5dc9e --- /dev/null +++ b/tests/luau/require/with_config_luau/src/alias_requirer_uc.luau @@ -0,0 +1 @@ +return require("@DeP") diff --git a/tests/luau/require/with_config_luau/src/dependency.luau b/tests/luau/require/with_config_luau/src/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/with_config_luau/src/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/with_config_luau/src/directory_alias_requirer.luau b/tests/luau/require/with_config_luau/src/directory_alias_requirer.luau new file mode 100644 index 00000000..3b19d4ff --- /dev/null +++ b/tests/luau/require/with_config_luau/src/directory_alias_requirer.luau @@ -0,0 +1 @@ +return(require("@subdir/subdirectory_dependency")) diff --git a/tests/luau/require/with_config_luau/src/other_dependency.luau b/tests/luau/require/with_config_luau/src/other_dependency.luau new file mode 100644 index 00000000..8c582dc2 --- /dev/null +++ b/tests/luau/require/with_config_luau/src/other_dependency.luau @@ -0,0 +1 @@ +return {"result from other_dependency"} diff --git a/tests/luau/require/with_config_luau/src/parent_alias_requirer.luau b/tests/luau/require/with_config_luau/src/parent_alias_requirer.luau new file mode 100644 index 00000000..a8e8de09 --- /dev/null +++ b/tests/luau/require/with_config_luau/src/parent_alias_requirer.luau @@ -0,0 +1 @@ +return require("@otherdep") diff --git a/tests/luau/require/with_config_luau/src/subdirectory/subdirectory_dependency.luau b/tests/luau/require/with_config_luau/src/subdirectory/subdirectory_dependency.luau new file mode 100644 index 00000000..8bbd0beb --- /dev/null +++ b/tests/luau/require/with_config_luau/src/subdirectory/subdirectory_dependency.luau @@ -0,0 +1 @@ +return {"result from subdirectory_dependency"} diff --git a/tests/luau/require/without_config/ambiguous/directory/dependency.luau b/tests/luau/require/without_config/ambiguous/directory/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/without_config/ambiguous/directory/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/without_config/ambiguous/directory/dependency/init.luau b/tests/luau/require/without_config/ambiguous/directory/dependency/init.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/without_config/ambiguous/directory/dependency/init.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/without_config/ambiguous/file/dependency.lua b/tests/luau/require/without_config/ambiguous/file/dependency.lua new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/without_config/ambiguous/file/dependency.lua @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/without_config/ambiguous/file/dependency.luau b/tests/luau/require/without_config/ambiguous/file/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/without_config/ambiguous/file/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/without_config/ambiguous_directory_requirer.luau b/tests/luau/require/without_config/ambiguous_directory_requirer.luau new file mode 100644 index 00000000..e46be806 --- /dev/null +++ b/tests/luau/require/without_config/ambiguous_directory_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/directory/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/luau/require/without_config/ambiguous_file_requirer.luau b/tests/luau/require/without_config/ambiguous_file_requirer.luau new file mode 100644 index 00000000..8e3a576d --- /dev/null +++ b/tests/luau/require/without_config/ambiguous_file_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/file/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/luau/require/without_config/dependency.luau b/tests/luau/require/without_config/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/luau/require/without_config/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/luau/require/without_config/lua/init.lua b/tests/luau/require/without_config/lua/init.lua new file mode 100644 index 00000000..7c28b735 --- /dev/null +++ b/tests/luau/require/without_config/lua/init.lua @@ -0,0 +1 @@ +return {"result from init.lua"} diff --git a/tests/luau/require/without_config/lua_dependency.lua b/tests/luau/require/without_config/lua_dependency.lua new file mode 100644 index 00000000..aec2d82b --- /dev/null +++ b/tests/luau/require/without_config/lua_dependency.lua @@ -0,0 +1 @@ +return {"result from lua_dependency"} diff --git a/tests/luau/require/without_config/luau/init.luau b/tests/luau/require/without_config/luau/init.luau new file mode 100644 index 00000000..72463463 --- /dev/null +++ b/tests/luau/require/without_config/luau/init.luau @@ -0,0 +1 @@ +return {"result from init.luau"} diff --git a/tests/luau/require/without_config/module.luau b/tests/luau/require/without_config/module.luau new file mode 100644 index 00000000..1d1393ff --- /dev/null +++ b/tests/luau/require/without_config/module.luau @@ -0,0 +1,3 @@ +local result = require("./dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/luau/require/without_config/nested/init.luau b/tests/luau/require/without_config/nested/init.luau new file mode 100644 index 00000000..75b9617d --- /dev/null +++ b/tests/luau/require/without_config/nested/init.luau @@ -0,0 +1,2 @@ +local result = require("@self/submodule") +return result diff --git a/tests/luau/require/without_config/nested/submodule.luau b/tests/luau/require/without_config/nested/submodule.luau new file mode 100644 index 00000000..9221587e --- /dev/null +++ b/tests/luau/require/without_config/nested/submodule.luau @@ -0,0 +1 @@ +return {"result from submodule"} diff --git a/tests/luau/require/without_config/nested_inits/init.luau b/tests/luau/require/without_config/nested_inits/init.luau new file mode 100644 index 00000000..9a36b68a --- /dev/null +++ b/tests/luau/require/without_config/nested_inits/init.luau @@ -0,0 +1,2 @@ +local result = require("@self/init") +return result diff --git a/tests/luau/require/without_config/nested_inits/init/init.luau b/tests/luau/require/without_config/nested_inits/init/init.luau new file mode 100644 index 00000000..0623c941 --- /dev/null +++ b/tests/luau/require/without_config/nested_inits/init/init.luau @@ -0,0 +1 @@ +return {"result from nested_inits/init"} diff --git a/tests/luau/require/without_config/nested_inits_requirer.luau b/tests/luau/require/without_config/nested_inits_requirer.luau new file mode 100644 index 00000000..6c4a0a5f --- /dev/null +++ b/tests/luau/require/without_config/nested_inits_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./nested_inits") +result[#result+1] = "required into module" +return result diff --git a/tests/luau/require/without_config/nested_module_requirer.luau b/tests/luau/require/without_config/nested_module_requirer.luau new file mode 100644 index 00000000..fc8d5e79 --- /dev/null +++ b/tests/luau/require/without_config/nested_module_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./nested") +result[#result+1] = "required into module" +return result diff --git a/tests/luau/require/without_config/validate_cache.luau b/tests/luau/require/without_config/validate_cache.luau new file mode 100644 index 00000000..dad139b3 --- /dev/null +++ b/tests/luau/require/without_config/validate_cache.luau @@ -0,0 +1,4 @@ +local result1 = require("./dependency") +local result2 = require("./dependency") +assert(result1 == result2, "expect the same result when requiring the same module twice") +return {} \ No newline at end of file diff --git a/tests/memory.rs b/tests/memory.rs index e359a5da..4f91221e 100644 --- a/tests/memory.rs +++ b/tests/memory.rs @@ -1,11 +1,11 @@ use std::sync::Arc; -use mlua::{GCMode, Lua, Result, UserData}; +use mlua::state::{GcIncParams, GcMode}; +use mlua::{Error, Lua, Result, UserData}; -#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] -use mlua::Error; +#[cfg(any(feature = "lua54", feature = "lua55"))] +use mlua::state::GcGenParams; -#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] #[test] fn test_memory_limit() -> Result<()> { let lua = Lua::new(); @@ -19,16 +19,54 @@ fn test_memory_limit() -> Result<()> { let f = lua .load("local t = {}; for i = 1,10000 do t[i] = i end") .into_function()?; - f.call::<_, ()>(()).expect("should trigger no memory limit"); + f.call::<()>(()).expect("should trigger no memory limit"); + + if cfg!(feature = "luajit") && lua.set_memory_limit(0).is_err() { + // seems this luajit version does not support memory limit + return Ok(()); + } lua.set_memory_limit(initial_memory + 10000)?; - match f.call::<_, ()>(()) { + match f.call::<()>(()) { Err(Error::MemoryError(_)) => {} something_else => panic!("did not trigger memory error: {:?}", something_else), }; lua.set_memory_limit(0)?; - f.call::<_, ()>(()).expect("should trigger no memory limit"); + f.call::<()>(()).expect("should trigger no memory limit"); + + // Test memory limit during chunk loading + lua.set_memory_limit(1024)?; + match lua + .load("local t = {}; for i = 1,10000 do t[i] = i end") + .into_function() + { + Err(Error::MemoryError(_)) => {} + _ => panic!("did not trigger memory error"), + }; + + Ok(()) +} + +#[test] +fn test_memory_limit_thread() -> Result<()> { + let lua = Lua::new(); + + let f = lua + .load("local t = {}; for i = 1,10000 do t[i] = i end") + .into_function()?; + + if cfg!(feature = "luajit") && lua.set_memory_limit(0).is_err() { + // seems this luajit version does not support memory limit + return Ok(()); + } + + let thread = lua.create_thread(f)?; + lua.set_memory_limit(lua.used_memory() + 10000)?; + match thread.resume::<()>(()) { + Err(Error::MemoryError(_)) => {} + something_else => panic!("did not trigger memory error: {:?}", something_else), + }; Ok(()) } @@ -38,13 +76,20 @@ fn test_gc_control() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); - #[cfg(feature = "lua54")] + #[cfg(any(feature = "lua55", feature = "lua54"))] { - assert_eq!(lua.gc_gen(0, 0), GCMode::Incremental); - assert_eq!(lua.gc_inc(0, 0, 0), GCMode::Generational); + assert!(matches!( + lua.gc_set_mode(GcMode::Generational(GcGenParams::default())), + GcMode::Incremental(_) + )); + assert!(matches!( + lua.gc_set_mode(GcMode::Incremental(GcIncParams::default())), + GcMode::Generational(_) + )); } #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -58,9 +103,19 @@ fn test_gc_control() -> Result<()> { assert!(lua.gc_is_running()); } - assert_eq!(lua.gc_inc(200, 100, 13), GCMode::Incremental); - - struct MyUserdata(Arc<()>); + assert!(matches!( + lua.gc_set_mode(GcMode::Incremental({ + let p = GcIncParams::default().step_multiplier(100); + #[cfg(not(feature = "luau"))] + let p = p.pause(200); + #[cfg(feature = "luau")] + let p = p.goal(200); + p + })), + GcMode::Incremental(_) + )); + + struct MyUserdata(#[allow(unused)] Arc<()>); impl UserData for MyUserdata {} let rc = Arc::new(()); diff --git a/tests/module/.cargo/config b/tests/module/.cargo/config deleted file mode 100644 index d47f983e..00000000 --- a/tests/module/.cargo/config +++ /dev/null @@ -1,11 +0,0 @@ -[target.x86_64-apple-darwin] -rustflags = [ - "-C", "link-arg=-undefined", - "-C", "link-arg=dynamic_lookup", -] - -[target.aarch64-apple-darwin] -rustflags = [ - "-C", "link-arg=-undefined", - "-C", "link-arg=dynamic_lookup", -] diff --git a/tests/module/Cargo.toml b/tests/module/Cargo.toml index adf98d9f..9cf87683 100644 --- a/tests/module/Cargo.toml +++ b/tests/module/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "rust_module" +name = "test_module" version = "0.0.0" authors = ["Aleksandr Orlenko "] -edition = "2018" +edition = "2021" [lib] crate-type = ["cdylib"] @@ -13,6 +13,7 @@ members = [ ] [features] +lua55 = ["mlua/lua55"] lua54 = ["mlua/lua54"] lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] diff --git a/tests/module/build.rs b/tests/module/build.rs new file mode 100644 index 00000000..eb5ca3c0 --- /dev/null +++ b/tests/module/build.rs @@ -0,0 +1,7 @@ +fn main() { + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-cdylib-link-arg=-undefined"); + println!("cargo:rustc-cdylib-link-arg=dynamic_lookup"); + } +} diff --git a/tests/module/loader/.cargo/config b/tests/module/loader/.cargo/config.toml similarity index 100% rename from tests/module/loader/.cargo/config rename to tests/module/loader/.cargo/config.toml diff --git a/tests/module/loader/Cargo.toml b/tests/module/loader/Cargo.toml index 180b2e94..ddd6123f 100644 --- a/tests/module/loader/Cargo.toml +++ b/tests/module/loader/Cargo.toml @@ -2,9 +2,10 @@ name = "module_loader" version = "0.0.0" authors = ["Aleksandr Orlenko "] -edition = "2018" +edition = "2021" [features] +lua55 = ["mlua/lua55"] lua54 = ["mlua/lua54"] lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] diff --git a/tests/module/loader/tests/load.rs b/tests/module/loader/tests/load.rs index 15dc5ae6..873da7d8 100644 --- a/tests/module/loader/tests/load.rs +++ b/tests/module/loader/tests/load.rs @@ -4,11 +4,11 @@ use std::path::PathBuf; use mlua::{Lua, Result}; #[test] -fn test_module() -> Result<()> { +fn test_module_simple() -> Result<()> { let lua = make_lua()?; lua.load( r#" - local mod = require("rust_module") + local mod = require("test_module") assert(mod.sum(2,2) == 4) "#, ) @@ -20,8 +20,8 @@ fn test_module_multi() -> Result<()> { let lua = make_lua()?; lua.load( r#" - local mod = require("rust_module") - local mod2 = require("rust_module.second") + local mod = require("test_module") + local mod2 = require("test_module.second") assert(mod.check_userdata(mod2.userdata) == 123) "#, ) @@ -33,7 +33,7 @@ fn test_module_error() -> Result<()> { let lua = make_lua()?; lua.load( r#" - local ok, err = pcall(require, "rust_module.error") + local ok, err = pcall(require, "test_module.error") assert(not ok) assert(string.find(tostring(err), "custom module error")) "#, @@ -42,6 +42,7 @@ fn test_module_error() -> Result<()> { } #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -55,11 +56,12 @@ fn test_module_from_thread() -> Result<()> { local mod local co = coroutine.create(function(a, b) - mod = require("rust_module") + mod = require("test_module") assert(mod.sum(a, b) == a + b) end) - coroutine.resume(co, 3, 5) + local ok, err = coroutine.resume(co, 3, 5) + assert(ok, err) collectgarbage() assert(mod.used_memory() > 0) @@ -68,6 +70,42 @@ fn test_module_from_thread() -> Result<()> { .exec() } +#[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "lua51" +))] +#[test] +fn test_module_multi_from_thread() -> Result<()> { + let lua = make_lua()?; + lua.load( + r#" + local mod = require("test_module") + local co = coroutine.create(function() + local mod2 = require("test_module.second") + assert(mod2.userdata ~= nil) + end) + local ok, err = coroutine.resume(co) + assert(ok, err) + "#, + ) + .exec() +} + +#[test] +fn test_module_new_vm() -> Result<()> { + let lua = make_lua()?; + lua.load( + r#" + local mod = require("test_module.new_vm") + assert(mod.eval("return \"hello, world\"") == "hello, world") + "#, + ) + .exec() +} + fn make_lua() -> Result { let (dylib_path, dylib_ext, separator); if cfg!(target_os = "macos") { diff --git a/tests/module/src/lib.rs b/tests/module/src/lib.rs index fdaf1716..ac505824 100644 --- a/tests/module/src/lib.rs +++ b/tests/module/src/lib.rs @@ -8,12 +8,12 @@ fn used_memory(lua: &Lua, _: ()) -> LuaResult { Ok(lua.used_memory()) } -fn check_userdata(_: &Lua, ud: MyUserData) -> LuaResult { - Ok(ud.0) +fn check_userdata(_: &Lua, ud: LuaAnyUserData) -> LuaResult { + Ok(ud.borrow::()?.0) } #[mlua::lua_module] -fn rust_module(lua: &Lua) -> LuaResult { +fn test_module(lua: &Lua) -> LuaResult { let exports = lua.create_table()?; exports.set("sum", lua.create_function(sum)?)?; exports.set("used_memory", lua.create_function(used_memory)?)?; @@ -26,14 +26,26 @@ struct MyUserData(i32); impl LuaUserData for MyUserData {} +#[mlua::lua_module(name = "test_module_second", skip_memory_check)] +fn test_module2(lua: &Lua) -> LuaResult { + let exports = lua.create_table()?; + exports.set("userdata", MyUserData(123))?; + Ok(exports) +} + #[mlua::lua_module] -fn rust_module_second(lua: &Lua) -> LuaResult { +fn test_module_new_vm(lua: &Lua) -> LuaResult { + let eval = lua.create_function(|_, prog: String| { + let lua = Lua::new(); + lua.load(prog).eval::>() + })?; + let exports = lua.create_table()?; - exports.set("userdata", lua.create_userdata(MyUserData(123))?)?; + exports.set("eval", eval)?; Ok(exports) } #[mlua::lua_module] -fn rust_module_error(_: &Lua) -> LuaResult { - Err("custom module error".to_lua_err()) +fn test_module_error(_: &Lua) -> LuaResult { + Err("custom module error".into_lua_err()) } diff --git a/tests/multi.rs b/tests/multi.rs new file mode 100644 index 00000000..9fe43bfc --- /dev/null +++ b/tests/multi.rs @@ -0,0 +1,107 @@ +use mlua::{ + Error, ExternalError, Integer, IntoLuaMulti, Lua, LuaString, MultiValue, Result, Value, Variadic, +}; + +#[test] +fn test_result_conversions() -> Result<()> { + let lua = Lua::new(); + let globals = lua.globals(); + + let ok = lua.create_function(|_, ()| Ok(Ok::<(), Error>(())))?; + let err = lua.create_function(|_, ()| Ok(Err::<(), _>("failure1".into_lua_err())))?; + let ok2 = lua.create_function(|_, ()| Ok(Ok::<_, Error>("!".to_owned())))?; + let err2 = lua.create_function(|_, ()| Ok(Err::("failure2".into_lua_err())))?; + + globals.set("ok", ok)?; + globals.set("ok2", ok2)?; + globals.set("err", err)?; + globals.set("err2", err2)?; + + lua.load( + r#" + local r, e = ok() + assert(r == nil and e == nil) + + local r, e = err() + assert(r == nil) + assert(tostring(e):find("failure1") ~= nil) + + local r, e = ok2() + assert(r == "!") + assert(e == nil) + + local r, e = err2() + assert(r == nil) + assert(tostring(e):find("failure2") ~= nil) + "#, + ) + .exec()?; + + // Try to convert Result into MultiValue + let ok1 = Ok::<(), Error>(()); + let multi_ok1 = ok1.into_lua_multi(&lua)?; + assert_eq!(multi_ok1.len(), 0); + let err1 = Err::<(), _>("failure1"); + let multi_err1 = err1.into_lua_multi(&lua)?; + assert_eq!(multi_err1.len(), 2); + assert_eq!(multi_err1[0], Value::Nil); + assert_eq!(multi_err1[1].as_string().unwrap(), "failure1"); + + let ok2 = Ok::<_, Error>("!"); + let multi_ok2 = ok2.into_lua_multi(&lua)?; + assert_eq!(multi_ok2.len(), 1); + assert_eq!(multi_ok2[0].as_string().unwrap(), "!"); + let err2 = Err::("failure2".into_lua_err()); + let multi_err2 = err2.into_lua_multi(&lua)?; + assert_eq!(multi_err2.len(), 2); + assert_eq!(multi_err2[0], Value::Nil); + assert!(matches!(multi_err2[1], Value::Error(_))); + assert_eq!(multi_err2[1].to_string()?, "failure2"); + + Ok(()) +} + +#[test] +fn test_multivalue() { + let mut multi = MultiValue::with_capacity(3); + multi.push_back(Value::Integer(1)); + multi.push_back(Value::Integer(2)); + multi.push_front(Value::Integer(3)); + assert_eq!(multi.iter().filter_map(|v| v.as_integer()).sum::(), 6); + + let vec = multi.into_vec(); + assert_eq!(&vec, &[Value::Integer(3), Value::Integer(1), Value::Integer(2)]); + let _multi2 = MultiValue::from_vec(vec); +} + +#[test] +fn test_multivalue_by_ref() -> Result<()> { + let lua = Lua::new(); + let multi = MultiValue::from_vec(vec![ + Value::Integer(3), + Value::String(lua.create_string("hello")?), + Value::Boolean(true), + ]); + + let f = lua.create_function(|_, (i, s, b): (i32, LuaString, bool)| { + assert_eq!(i, 3); + assert_eq!(s.to_str()?, "hello"); + assert_eq!(b, true); + Ok(()) + })?; + f.call::<()>(&multi)?; + + Ok(()) +} + +#[test] +fn test_variadic() { + let mut var = Variadic::with_capacity(3); + var.extend_from_slice(&[1, 2, 3]); + assert_eq!(var.iter().sum::(), 6); + + let vec = Vec::::from(var); + assert_eq!(&vec, &[1, 2, 3]); + let var2 = Variadic::from(vec); + assert_eq!(var2.as_slice(), &[1, 2, 3]); +} diff --git a/tests/scope.rs b/tests/scope.rs index 96403839..9b16fcdd 100644 --- a/tests/scope.rs +++ b/tests/scope.rs @@ -3,8 +3,8 @@ use std::rc::Rc; use std::sync::Arc; use mlua::{ - AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields, - UserDataMethods, + AnyUserData, Error, Function, Lua, LuaString, MetaMethod, ObjectLike, Result, UserData, UserDataFields, + UserDataMethods, UserDataRegistry, }; #[test] @@ -13,20 +13,20 @@ fn test_scope_func() -> Result<()> { let rc = Rc::new(Cell::new(0)); lua.scope(|scope| { - let r = rc.clone(); + let rc2 = rc.clone(); let f = scope.create_function(move |_, ()| { - r.set(42); + rc2.set(42); Ok(()) })?; - lua.globals().set("bad", f.clone())?; - f.call::<_, ()>(())?; + lua.globals().set("f", &f)?; + f.call::<()>(())?; assert_eq!(Rc::strong_count(&rc), 2); Ok(()) })?; assert_eq!(rc.get(), 42); assert_eq!(Rc::strong_count(&rc), 1); - match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) { + match lua.globals().get::("f")?.call::<()>(()) { Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { Error::CallbackDestructed => {} ref err => panic!("wrong error type {:?}", err), @@ -48,7 +48,7 @@ fn test_scope_capture() -> Result<()> { i = 42; Ok(()) })? - .call::<_, ()>(()) + .call::<()>(()) })?; assert_eq!(i, 42); @@ -60,12 +60,29 @@ fn test_scope_outer_lua_access() -> Result<()> { let lua = Lua::new(); let table = lua.create_table()?; + lua.scope(|scope| scope.create_function(|_, ()| table.set("a", "b"))?.call::<()>(()))?; + assert_eq!(table.get::("a")?, "b"); + + Ok(()) +} + +#[test] +fn test_scope_capture_scope() -> Result<()> { + let lua = Lua::new(); + + let i = Cell::new(0); lua.scope(|scope| { - scope - .create_function_mut(|_, ()| table.set("a", "b"))? - .call::<_, ()>(()) + let f = scope.create_function(|_, ()| { + scope.create_function(|_, n: u32| { + i.set(i.get() + n); + Ok(()) + }) + })?; + f.call::(())?.call::<()>(10)?; + Ok(()) })?; - assert_eq!(table.get::<_, String>("a")?, "b"); + + assert_eq!(i.get(), 10); Ok(()) } @@ -74,10 +91,11 @@ fn test_scope_outer_lua_access() -> Result<()> { fn test_scope_userdata_fields() -> Result<()> { struct MyUserData<'a>(&'a Cell); - impl<'a> UserData for MyUserData<'a> { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - fields.add_field_method_get("val", |_, data| Ok(data.0.get())); - fields.add_field_method_set("val", |_, data, val| { + impl UserData for MyUserData<'_> { + fn register(reg: &mut UserDataRegistry) { + reg.add_field("field", "hello"); + reg.add_field_method_get("val", |_, data| Ok(data.0.get())); + reg.add_field_method_set("val", |_, data, val| { data.0.set(val); Ok(()) }); @@ -91,6 +109,7 @@ fn test_scope_userdata_fields() -> Result<()> { .load( r#" function(u) + assert(u.field == "hello") assert(u.val == 42) u.val = 44 end @@ -98,7 +117,7 @@ fn test_scope_userdata_fields() -> Result<()> { ) .eval()?; - lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&i))?))?; + lua.scope(|scope| f.call::<()>(scope.create_userdata(MyUserData(&i))?))?; assert_eq!(i.get(), 44); @@ -109,14 +128,14 @@ fn test_scope_userdata_fields() -> Result<()> { fn test_scope_userdata_methods() -> Result<()> { struct MyUserData<'a>(&'a Cell); - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("inc", |_, data, ()| { + impl UserData for MyUserData<'_> { + fn register(reg: &mut UserDataRegistry) { + reg.add_method("inc", |_, data, ()| { data.0.set(data.0.get() + 1); Ok(()) }); - methods.add_method("dec", |_, data, ()| { + reg.add_method("dec", |_, data, ()| { data.0.set(data.0.get() - 1); Ok(()) }); @@ -139,7 +158,7 @@ fn test_scope_userdata_methods() -> Result<()> { ) .eval()?; - lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&i))?))?; + lua.scope(|scope| f.call::<()>(scope.create_userdata(MyUserData(&i))?))?; assert_eq!(i.get(), 44); @@ -147,19 +166,19 @@ fn test_scope_userdata_methods() -> Result<()> { } #[test] -fn test_scope_userdata_functions() -> Result<()> { +fn test_scope_userdata_ops() -> Result<()> { struct MyUserData<'a>(&'a i64); - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_meta_function(MetaMethod::Add, |lua, ()| { + impl UserData for MyUserData<'_> { + fn register(reg: &mut UserDataRegistry) { + reg.add_meta_method(MetaMethod::Add, |lua, this, ()| { let globals = lua.globals(); - globals.set("i", globals.get::<_, i64>("i")? + 1)?; + globals.set("i", globals.get::("i")? + this.0)?; Ok(()) }); - methods.add_meta_function(MetaMethod::Sub, |lua, ()| { + reg.add_meta_method(MetaMethod::Sub, |lua, this, ()| { let globals = lua.globals(); - globals.set("i", globals.get::<_, i64>("i")? + 1)?; + globals.set("i", globals.get::("i")? + this.0)?; Ok(()) }); } @@ -167,7 +186,7 @@ fn test_scope_userdata_functions() -> Result<()> { let lua = Lua::new(); - let dummy = 0; + let dummy = 1; let f = lua .load( r#" @@ -175,27 +194,54 @@ fn test_scope_userdata_functions() -> Result<()> { return function(u) _ = u + u _ = u - 1 - _ = 1 + u + _ = u + 1 end "#, ) .eval::()?; - lua.scope(|scope| f.call::<_, ()>(scope.create_nonstatic_userdata(MyUserData(&dummy))?))?; + lua.scope(|scope| f.call::<()>(scope.create_userdata(MyUserData(&dummy))?))?; + + assert_eq!(lua.globals().get::("i")?, 3); + + Ok(()) +} + +#[test] +fn test_scope_userdata_values() -> Result<()> { + struct MyUserData<'a>(&'a i64); - assert_eq!(lua.globals().get::<_, i64>("i")?, 3); + impl UserData for MyUserData<'_> { + fn register(registry: &mut UserDataRegistry) { + registry.add_method("get", |_, data, ()| Ok(*data.0)); + } + } + + let lua = Lua::new(); + + let i = 42; + let data = MyUserData(&i); + lua.scope(|scope| { + let ud = scope.create_userdata(data)?; + assert_eq!(ud.call_method::("get", &ud)?, 42); + ud.set_user_value("user_value")?; + assert_eq!(ud.user_value::()?, "user_value"); + Ok(()) + })?; Ok(()) } #[test] fn test_scope_userdata_mismatch() -> Result<()> { - struct MyUserData<'a>(&'a Cell); + struct MyUserData<'a>(&'a mut i64); impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("inc", |_, data, ()| { - data.0.set(data.0.get() + 1); + fn register(reg: &mut UserDataRegistry) { + reg.add_method("get", |_, data, ()| Ok(*data.0)); + + reg.add_method_mut("inc", |_, data, ()| { + *data.0 = data.0.wrapping_add(1); Ok(()) }); } @@ -205,34 +251,54 @@ fn test_scope_userdata_mismatch() -> Result<()> { lua.load( r#" - function okay(a, b) - a.inc(a) - b.inc(b) - end - function bad(a, b) - a.inc(b) - end + function inc(a, b) a.inc(b) end + function get(a, b) a.get(b) end "#, ) .exec()?; - let a = Cell::new(1); - let b = Cell::new(1); - - let okay: Function = lua.globals().get("okay")?; - let bad: Function = lua.globals().get("bad")?; + let mut a = 1; + let mut b = 1; lua.scope(|scope| { - let au = scope.create_nonstatic_userdata(MyUserData(&a))?; - let bu = scope.create_nonstatic_userdata(MyUserData(&b))?; - assert!(okay.call::<_, ()>((au.clone(), bu.clone())).is_ok()); - match bad.call::<_, ()>((au, bu)) { - Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { - Error::UserDataTypeMismatch => {} - ref other => panic!("wrong error type {:?}", other), - }, - Err(other) => panic!("wrong error type {:?}", other), - Ok(_) => panic!("incorrectly returned Ok"), + let au = scope.create_userdata(MyUserData(&mut a))?; + let bu = scope.create_userdata(MyUserData(&mut b))?; + for method_name in ["get", "inc"] { + let f: Function = lua.globals().get(method_name)?; + let full_name = format!("MyUserData.{method_name}"); + let full_name = full_name.as_str(); + + assert!(f.call::<()>((&au, &au)).is_ok()); + match f.call::<()>((&au, &bu)) { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::BadArgument { to, pos, name, cause } => { + assert_eq!(to.as_deref(), Some(full_name)); + assert_eq!(*pos, 1); + assert_eq!(name.as_deref(), Some("self")); + assert!(matches!(*cause.as_ref(), Error::UserDataTypeMismatch)); + } + other => panic!("wrong error type {other:?}"), + }, + Err(other) => panic!("wrong error type {other:?}"), + Ok(_) => panic!("incorrectly returned Ok"), + } + + // Pass non-userdata type + let err = f.call::<()>((&au, 321)).err().unwrap(); + match err { + Error::CallbackError { ref cause, .. } => match cause.as_ref() { + Error::BadArgument { to, pos, name, cause } => { + assert_eq!(to.as_deref(), Some(full_name)); + assert_eq!(*pos, 1); + assert_eq!(name.as_deref(), Some("self")); + assert!(matches!(*cause.as_ref(), Error::FromLuaConversionError { .. })); + } + other => panic!("wrong error type {other:?}"), + }, + other => panic!("wrong error type {other:?}"), + } + let err_msg = format!("bad argument `self` to `{full_name}`: error converting Lua number to userdata (expected userdata of type 'MyUserData')"); + assert!(err.to_string().contains(&err_msg)); } Ok(()) })?; @@ -244,115 +310,262 @@ fn test_scope_userdata_mismatch() -> Result<()> { fn test_scope_userdata_drop() -> Result<()> { let lua = Lua::new(); - struct MyUserData(Rc<()>); + struct MyUserData<'a>(&'a Cell, #[allow(unused)] Rc<()>); - impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("method", |_, _, ()| Ok(())); + impl UserData for MyUserData<'_> { + fn register(reg: &mut UserDataRegistry) { + reg.add_method("inc", |_, data, ()| { + data.0.set(data.0.get() + 1); + Ok(()) + }); } } - struct MyUserDataArc(Arc<()>); - - impl UserData for MyUserDataArc {} - - let rc = Rc::new(()); - let arc = Arc::new(()); + let (i, rc) = (Cell::new(1), Rc::new(())); lua.scope(|scope| { - let ud = scope.create_userdata(MyUserData(rc.clone()))?; - ud.set_user_value(MyUserDataArc(arc.clone()))?; + let ud = scope.create_userdata(MyUserData(&i, rc.clone()))?; lua.globals().set("ud", ud)?; + lua.load("ud:inc()").exec()?; assert_eq!(Rc::strong_count(&rc), 2); - assert_eq!(Arc::strong_count(&arc), 2); Ok(()) })?; - - lua.gc_collect()?; assert_eq!(Rc::strong_count(&rc), 1); - assert_eq!(Arc::strong_count(&arc), 1); + assert_eq!(i.get(), 2); - match lua.load("ud:method()").exec() { + match lua.load("ud:inc()").exec() { Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - err => panic!("expected CallbackDestructed, got {:?}", err), + Error::UserDataDestructed => {} + err => panic!("expected UserDataDestructed, got {err:?}"), }, - r => panic!("improper return for destructed userdata: {:?}", r), + r => panic!("improper return for destructed userdata: {r:?}"), }; - let ud = lua.globals().get::<_, AnyUserData>("ud")?; - match ud.borrow::() { - Ok(_) => panic!("succesfull borrow for destructed userdata"), + let ud = lua.globals().get::("ud")?; + match ud.borrow_scoped::(|_| Ok::<_, Error>(())) { + Ok(_) => panic!("successful borrow for destructed userdata"), Err(Error::UserDataDestructed) => {} - Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err), + Err(err) => panic!("improper borrow error for destructed userdata: {err:?}"), } - - match ud.get_metatable() { + match ud.metatable() { Ok(_) => panic!("successful metatable retrieval of destructed userdata"), Err(Error::UserDataDestructed) => {} - Err(err) => panic!( - "improper metatable error for destructed userdata: {:?}", - err - ), + Err(err) => panic!("improper metatable error for destructed userdata: {err:?}"), } Ok(()) } #[test] -fn test_scope_nonstatic_userdata_drop() -> Result<()> { +fn test_scope_userdata_ref() -> Result<()> { let lua = Lua::new(); - struct MyUserData<'a>(&'a Cell, Arc<()>); + struct MyUserData(Cell); - impl<'a> UserData for MyUserData<'a> { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { methods.add_method("inc", |_, data, ()| { data.0.set(data.0.get() + 1); Ok(()) }); + + methods.add_method("dec", |_, data, ()| { + data.0.set(data.0.get() - 1); + Ok(()) + }); } } - struct MyUserDataArc(Arc<()>); + let data = MyUserData(Cell::new(1)); + lua.scope(|scope| { + let ud = scope.create_userdata_ref(&data)?; + modify_userdata(&lua, &ud)?; + + // We can only borrow userdata scoped + #[rustfmt::skip] + assert!(matches!(ud.borrow::(), Err(Error::UserDataTypeMismatch))); + ud.borrow_scoped::(|ud_inst| { + assert_eq!(ud_inst.0.get(), 2); + })?; - impl UserData for MyUserDataArc {} + Ok(()) + })?; + assert_eq!(data.0.get(), 2); + + Ok(()) +} + +#[test] +fn test_scope_userdata_ref_mut() -> Result<()> { + let lua = Lua::new(); + + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { + methods.add_method_mut("inc", |_, data, ()| { + data.0 += 1; + Ok(()) + }); + + methods.add_method_mut("dec", |_, data, ()| { + data.0 -= 1; + Ok(()) + }); + } + } - let i = Cell::new(1); - let arc = Arc::new(()); + let mut data = MyUserData(1); lua.scope(|scope| { - let ud = scope.create_nonstatic_userdata(MyUserData(&i, arc.clone()))?; - ud.set_user_value(MyUserDataArc(arc.clone()))?; - lua.globals().set("ud", ud)?; - lua.load("ud:inc()").exec()?; - assert_eq!(Arc::strong_count(&arc), 3); + let ud = scope.create_userdata_ref_mut(&mut data)?; + modify_userdata(&lua, &ud)?; + + #[rustfmt::skip] + assert!(matches!(ud.borrow_mut::(), Err(Error::UserDataTypeMismatch))); + ud.borrow_mut_scoped::(|ud_inst| { + ud_inst.0 += 10; + })?; + Ok(()) })?; + assert_eq!(data.0, 12); - lua.gc_collect()?; - assert_eq!(Arc::strong_count(&arc), 1); + Ok(()) +} - match lua.load("ud:inc()").exec() { +#[test] +fn test_scope_any_userdata() -> Result<()> { + let lua = Lua::new(); + + fn register(reg: &mut UserDataRegistry<&mut String>) { + reg.add_method_mut("push", |_, this, s: LuaString| { + this.push_str(&s.to_str()?); + Ok(()) + }); + reg.add_meta_method("__tostring", |_, data, ()| Ok((*data).clone())); + } + + let mut data = String::from("foo"); + lua.scope(|scope| { + let ud = scope.create_any_userdata(&mut data, register)?; + lua.globals().set("ud", ud)?; + lua.load( + r#" + assert(tostring(ud) == "foo") + ud:push("bar") + assert(tostring(ud) == "foobar") + "#, + ) + .exec() + })?; + + // Check that userdata is destructed + match lua.load("tostring(ud)").exec() { Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - err => panic!("expected CallbackDestructed, got {:?}", err), + Error::UserDataDestructed => {} + err => panic!("expected CallbackDestructed, got {err:?}"), }, - r => panic!("improper return for destructed userdata: {:?}", r), + r => panic!("improper return for destructed userdata: {r:?}"), }; - let ud = lua.globals().get::<_, AnyUserData>("ud")?; - match ud.borrow::() { - Ok(_) => panic!("succesfull borrow for destructed userdata"), - Err(Error::UserDataDestructed) => {} - Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err), - } - match ud.get_metatable() { - Ok(_) => panic!("successful metatable retrieval of destructed userdata"), - Err(Error::UserDataDestructed) => {} - Err(err) => panic!( - "improper metatable error for destructed userdata: {:?}", - err - ), - } + Ok(()) +} + +#[test] +fn test_scope_any_userdata_ref() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::>(|reg| { + reg.add_method("inc", |_, data, ()| { + data.set(data.get() + 1); + Ok(()) + }); + + reg.add_method("dec", |_, data, ()| { + data.set(data.get() - 1); + Ok(()) + }); + })?; + + let data = Cell::new(1i64); + lua.scope(|scope| { + let ud = scope.create_any_userdata_ref(&data)?; + modify_userdata(&lua, &ud) + })?; + assert_eq!(data.get(), 2); Ok(()) } + +#[test] +fn test_scope_any_userdata_ref_mut() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::(|reg| { + reg.add_method_mut("inc", |_, data, ()| { + *data += 1; + Ok(()) + }); + + reg.add_method_mut("dec", |_, data, ()| { + *data -= 1; + Ok(()) + }); + })?; + + let mut data = 1i64; + lua.scope(|scope| { + let ud = scope.create_any_userdata_ref_mut(&mut data)?; + modify_userdata(&lua, &ud) + })?; + assert_eq!(data, 2); + + Ok(()) +} + +#[test] +fn test_scope_destructors() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::>(|reg| { + reg.add_meta_method("__tostring", |_, data, ()| Ok(data.to_string())); + })?; + + let arc_str = Arc::new(String::from("foo")); + + let ud = lua.create_any_userdata(arc_str.clone())?; + lua.scope(|scope| { + scope.add_destructor(|| { + assert!(ud.destroy().is_ok()); + }); + Ok(()) + })?; + assert_eq!(Arc::strong_count(&arc_str), 1); + + // Try destructing the userdata while it's borrowed + let ud = lua.create_any_userdata(arc_str.clone())?; + ud.borrow_scoped::, _>(|arc_str| { + assert_eq!(arc_str.as_str(), "foo"); + lua.scope(|scope| { + scope.add_destructor(|| { + assert!(ud.destroy().is_err()); + }); + Ok(()) + }) + .unwrap(); + assert_eq!(arc_str.as_str(), "foo"); + })?; + + Ok(()) +} + +fn modify_userdata(lua: &Lua, ud: &AnyUserData) -> Result<()> { + lua.load( + r#" + local u = ... + u:inc() + u:dec() + u:inc() +"#, + ) + .call(ud) +} diff --git a/tests/send.rs b/tests/send.rs new file mode 100644 index 00000000..2f10466a --- /dev/null +++ b/tests/send.rs @@ -0,0 +1,41 @@ +#![cfg(feature = "send")] + +use mlua::{AnyUserData, Lua, ObjectLike, Result, UserData, UserDataMethods, UserDataRef}; +use static_assertions::assert_impl_all; + +#[test] +fn test_userdata_multithread_access_sync() -> Result<()> { + let lua = Lua::new(); + + // This type is `Send` and `Sync`. + struct MyUserData(String); + assert_impl_all!(MyUserData: Send, Sync); + + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { + methods.add_method("method", |lua, this, ()| { + let ud = lua.globals().get::("ud")?; + assert!(ud.call_method::<()>("method2", ()).is_ok()); + Ok(this.0.clone()) + }); + + methods.add_method("method2", |_, _, ()| Ok(())); + } + } + + lua.globals().set("ud", MyUserData("hello".to_string()))?; + + // We acquired the shared reference. + let _ud = lua.globals().get::>("ud")?; + + std::thread::scope(|s| { + s.spawn(|| { + // Getting another shared reference for `Sync` type is allowed. + let _ = lua.globals().get::>("ud").unwrap(); + }); + }); + + lua.load("ud:method()").exec().unwrap(); + + Ok(()) +} diff --git a/tests/serde.rs b/tests/serde.rs index db0d6739..d3cc00ed 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -1,15 +1,17 @@ -#![cfg(feature = "serialize")] +#![cfg(feature = "serde")] use std::collections::HashMap; +use std::error::Error as StdError; +use bstr::BString; use mlua::{ - DeserializeOptions, Error, Lua, LuaSerdeExt, Result as LuaResult, SerializeOptions, UserData, - Value, + AnyUserData, DeserializeOptions, Error, ExternalResult, IntoLua, Lua, LuaSerdeExt, Result as LuaResult, + SerializeOptions, UserData, Value, }; use serde::{Deserialize, Serialize}; #[test] -fn test_serialize() -> Result<(), Box> { +fn test_serialize() -> Result<(), Box> { #[derive(Serialize)] struct MyUserData(i64, String); @@ -23,7 +25,7 @@ fn test_serialize() -> Result<(), Box> { globals.set("null", lua.null())?; let empty_array = lua.create_table()?; - empty_array.set_metatable(Some(lua.array_metatable())); + empty_array.set_metatable(Some(lua.array_metatable()))?; globals.set("empty_array", empty_array)?; let val = lua @@ -34,7 +36,7 @@ fn test_serialize() -> Result<(), Box> { _integer = 123, _number = 321.99, _string = "test string serialization", - _table_arr = {nil, "value 1", nil, "value 2", {}}, + _table_arr = {null, "value 1", 2, "value 3", {}}, _table_map = {["table"] = "map", ["null"] = null}, _bytes = "\240\040\140\040", _userdata = ud, @@ -51,7 +53,7 @@ fn test_serialize() -> Result<(), Box> { "_integer": 123, "_number": 321.99, "_string": "test string serialization", - "_table_arr": [null, "value 1", null, "value 2", {}], + "_table_arr": [null, "value 1", 2, "value 3", {}], "_table_map": {"table": "map", "null": null}, "_bytes": [240, 40, 140, 40], "_userdata": [123, "test userdata"], @@ -71,51 +73,34 @@ fn test_serialize() -> Result<(), Box> { } #[test] -fn test_serialize_in_scope() -> LuaResult<()> { - #[derive(Serialize, Clone)] - struct MyUserData(i64, String); - - impl UserData for MyUserData {} - +fn test_serialize_any_userdata() { let lua = Lua::new(); - lua.scope(|scope| { - let ud = scope.create_ser_userdata(MyUserData(-5, "test userdata".into()))?; - assert_eq!( - serde_json::to_value(&ud).unwrap(), - serde_json::json!((-5, "test userdata")) - ); - Ok(()) - })?; - - lua.scope(|scope| { - let ud = scope.create_ser_userdata(MyUserData(-5, "test userdata".into()))?; - lua.globals().set("ud", ud) - })?; - let val = lua.load("ud").eval::()?; - match serde_json::to_value(&val) { - Ok(v) => panic!("expected destructed error, got {}", v), - Err(e) if e.to_string().contains("destructed") => {} - Err(e) => panic!("expected destructed error, got {}", e), - } - - struct MyUserDataRef<'a>(&'a ()); - impl<'a> UserData for MyUserDataRef<'a> {} + let json_val = serde_json::json!({ + "a": 1, + "b": "test", + }); + let json_ud = lua.create_ser_any_userdata(json_val).unwrap(); + let json_str = serde_json::to_string_pretty(&json_ud).unwrap(); + assert_eq!(json_str, "{\n \"a\": 1,\n \"b\": \"test\"\n}"); +} - lua.scope(|scope| { - let ud = scope.create_nonstatic_userdata(MyUserDataRef(&()))?; - match serde_json::to_value(&ud) { - Ok(v) => panic!("expected serialization error, got {}", v), - Err(serde_json::Error { .. }) => {} - }; - Ok(()) - })?; +#[test] +fn test_serialize_wrapped_any_userdata() { + let lua = Lua::new(); - Ok(()) + let json_val = serde_json::json!({ + "a": 1, + "b": "test", + }); + let ud = AnyUserData::wrap_ser(json_val); + let json_ud = ud.into_lua(&lua).unwrap(); + let json_str = serde_json::to_string(&json_ud).unwrap(); + assert_eq!(json_str, "{\"a\":1,\"b\":\"test\"}"); } #[test] -fn test_serialize_failure() -> Result<(), Box> { +fn test_serialize_failure() -> Result<(), Box> { #[derive(Serialize)] struct MyUserData(i64); @@ -144,18 +129,12 @@ fn test_serialize_failure() -> Result<(), Box> { Ok(()) } -#[cfg(feature = "luau")] +#[cfg(all(feature = "luau", not(feature = "luau-vector4")))] #[test] -fn test_serialize_vector() -> Result<(), Box> { +fn test_serialize_vector() -> Result<(), Box> { let lua = Lua::new(); - let globals = lua.globals(); - globals.set( - "vector", - lua.create_function(|_, (x, y, z)| Ok(Value::Vector(x, y, z)))?, - )?; - - let val = lua.load("{_vector = vector(1, 2, 3)}").eval::()?; + let val = lua.load("{_vector = vector.create(1, 2, 3)}").eval::()?; let json = serde_json::json!({ "_vector": [1.0, 2.0, 3.0], }); @@ -167,6 +146,168 @@ fn test_serialize_vector() -> Result<(), Box> { Ok(()) } +#[cfg(feature = "luau-vector4")] +#[test] +fn test_serialize_vector() -> Result<(), Box> { + let lua = Lua::new(); + + let val = lua + .load("{_vector = vector.create(1, 2, 3, 4)}") + .eval::()?; + let json = serde_json::json!({ + "_vector": [1.0, 2.0, 3.0, 4.0], + }); + assert_eq!(serde_json::to_value(&val)?, json); + + let expected_json = lua.from_value::(val)?; + assert_eq!(expected_json, json); + + Ok(()) +} + +#[test] +fn test_serialize_sorted() -> LuaResult<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + globals.set("null", lua.null())?; + + let empty_array = lua.create_table()?; + empty_array.set_metatable(Some(lua.array_metatable()))?; + globals.set("empty_array", empty_array)?; + + let value = lua + .load( + r#" + { + _bool = true, + _integer = 123, + _number = 321.99, + _string = "test string serialization", + _table_arr = {null, "value 1", 2, "value 3", {}}, + _table_map = {["table"] = "map", ["null"] = null}, + _bytes = "\240\040\140\040", + _null = null, + _empty_map = {}, + _empty_array = empty_array, + } + "#, + ) + .eval::()?; + + let json = serde_json::to_string(&value.to_serializable().sort_keys(true)).unwrap(); + assert_eq!( + json, + r#"{"_bool":true,"_bytes":[240,40,140,40],"_empty_array":[],"_empty_map":{},"_integer":123,"_null":null,"_number":321.99,"_string":"test string serialization","_table_arr":[null,"value 1",2,"value 3",{}],"_table_map":{"null":null,"table":"map"}}"# + ); + + Ok(()) +} + +#[test] +fn test_serialize_globals() -> LuaResult<()> { + let lua = Lua::new(); + + let globals = Value::Table(lua.globals()); + + // By default it should not work + if let Ok(v) = serde_json::to_value(&globals) { + panic!("expected serialization error, got {v:?}"); + } + + // It should work with `deny_recursive_tables` and `deny_unsupported_types` disabled + if let Err(err) = serde_json::to_value( + globals + .to_serializable() + .deny_recursive_tables(false) + .deny_unsupported_types(false), + ) { + panic!("expected no errors, got {err:?}"); + } + + Ok(()) +} + +#[test] +fn test_serialize_same_table_twice() -> LuaResult<()> { + let lua = Lua::new(); + + let value = lua + .load( + r#" + local foo = {} + return { + a = foo, + b = foo, + } + "#, + ) + .eval::()?; + let json = serde_json::to_string(&value.to_serializable().sort_keys(true)).unwrap(); + assert_eq!(json, r#"{"a":{},"b":{}}"#); + + Ok(()) +} + +#[test] +fn test_serialize_empty_table() -> LuaResult<()> { + let lua = Lua::new(); + + let table = Value::Table(lua.create_table()?); + let json = serde_json::to_string(&table.to_serializable()).unwrap(); + assert_eq!(json, "{}"); + + // Set the option to encode empty tables as array + let json = serde_json::to_string(&table.to_serializable().encode_empty_tables_as_array(true)).unwrap(); + assert_eq!(json, "[]"); + + // Check hashmap table with this option + table.as_table().unwrap().set("hello", "world")?; + let json = serde_json::to_string(&table.to_serializable().encode_empty_tables_as_array(true)).unwrap(); + assert_eq!(json, r#"{"hello":"world"}"#); + + Ok(()) +} + +#[test] +fn test_serialize_mixed_table() -> LuaResult<()> { + let lua = Lua::new(); + + // Check that sparse array is serialized similarly when using direct serialization + // and via `Lua::from_value` + let table = lua.load("{1,2,3,nil,5}").eval::()?; + let json1 = serde_json::to_string(&table).unwrap(); + let json2 = lua.from_value::(table)?; + assert_eq!(json1, json2.to_string()); + + // A table with several borders should be correctly encoded when `detect_mixed_tables` is enabled + let table = lua + .load( + r#" + local t = {1,2,3,nil,5,6} + t[10] = 10 + return t + "#, + ) + .eval::()?; + let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap(); + assert_eq!(json, r#"[1,2,3,null,5,6,null,null,null,10]"#); + + // A mixed table with both array-like and map-like entries + let table = lua.load(r#"{1,2,3, key="value"}"#).eval::()?; + let json = serde_json::to_string(&table).unwrap(); + assert_eq!(json, r#"[1,2,3]"#); + let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap(); + assert_eq!(json, r#"{"1":1,"2":2,"3":3,"key":"value"}"#); + + // A mixed table with duplicate keys of different types + let table = lua.load(r#"{1,2,3, ["1"]="value"}"#).eval::()?; + let json = serde_json::to_string(&table.to_serializable().detect_mixed_tables(true)).unwrap(); + assert_eq!(json, r#"{"1":1,"2":2,"3":3,"1":"value"}"#); + + Ok(()) +} + #[test] fn test_to_value_struct() -> LuaResult<()> { let lua = Lua::new(); @@ -235,7 +376,7 @@ fn test_to_value_enum() -> LuaResult<()> { } #[test] -fn test_to_value_with_options() -> Result<(), Box> { +fn test_to_value_with_options() -> Result<(), Box> { let lua = Lua::new(); let globals = lua.globals(); globals.set("null", lua.null())?; @@ -272,10 +413,7 @@ fn test_to_value_with_options() -> Result<(), Box> { unit: (), unitstruct: UnitStruct, }; - let data2 = lua.to_value_with( - &mydata, - SerializeOptions::new().serialize_none_to_null(false), - )?; + let data2 = lua.to_value_with(&mydata, SerializeOptions::new().serialize_none_to_null(false))?; globals.set("data2", data2)?; lua.load( r#" @@ -287,10 +425,7 @@ fn test_to_value_with_options() -> Result<(), Box> { .exec()?; // serialize_unit_to_null - let data3 = lua.to_value_with( - &mydata, - SerializeOptions::new().serialize_unit_to_null(false), - )?; + let data3 = lua.to_value_with(&mydata, SerializeOptions::new().serialize_unit_to_null(false))?; globals.set("data3", data3)?; lua.load( r#" @@ -305,7 +440,7 @@ fn test_to_value_with_options() -> Result<(), Box> { } #[test] -fn test_from_value_nested_tables() -> Result<(), Box> { +fn test_from_value_nested_tables() -> Result<(), Box> { let lua = Lua::new(); let value = lua @@ -335,7 +470,7 @@ fn test_from_value_nested_tables() -> Result<(), Box> { } #[test] -fn test_from_value_struct() -> Result<(), Box> { +fn test_from_value_struct() -> Result<(), Box> { let lua = Lua::new(); #[derive(Deserialize, PartialEq, Debug)] @@ -345,6 +480,7 @@ fn test_from_value_struct() -> Result<(), Box> { map: HashMap, empty: Vec<()>, tuple: (u8, u8, u8), + bytes: BString, } let value = lua @@ -356,6 +492,7 @@ fn test_from_value_struct() -> Result<(), Box> { map = {2, [4] = 1}, empty = {}, tuple = {10, 20, 30}, + bytes = "\240\040\140\040", } "#, ) @@ -368,6 +505,7 @@ fn test_from_value_struct() -> Result<(), Box> { map: vec![(1, 2), (4, 1)].into_iter().collect(), empty: vec![], tuple: (10, 20, 30), + bytes: BString::from([240, 40, 140, 40]), }, got ); @@ -376,38 +514,64 @@ fn test_from_value_struct() -> Result<(), Box> { } #[test] -fn test_from_value_enum() -> Result<(), Box> { +fn test_from_value_newtype_struct() -> Result<(), Box> { let lua = Lua::new(); #[derive(Deserialize, PartialEq, Debug)] - enum E { + struct Test(f64); + + let got = lua.from_value(Value::Number(123.456))?; + assert_eq!(Test(123.456), got); + + Ok(()) +} + +#[test] +fn test_from_value_enum() -> Result<(), Box> { + let lua = Lua::new(); + lua.globals().set("null", lua.null())?; + + #[derive(Deserialize, PartialEq, Debug)] + struct UnitStruct; + + #[derive(Deserialize, PartialEq, Debug)] + enum E { Unit, Integer(u32), Tuple(u32, u32), Struct { a: u32 }, + Wrap(T), } let value = lua.load(r#""Unit""#).eval()?; - let got = lua.from_value(value)?; + let got: E = lua.from_value(value)?; assert_eq!(E::Unit, got); let value = lua.load(r#"{Integer = 1}"#).eval()?; - let got = lua.from_value(value)?; + let got: E = lua.from_value(value)?; assert_eq!(E::Integer(1), got); let value = lua.load(r#"{Tuple = {1, 2}}"#).eval()?; - let got = lua.from_value(value)?; + let got: E = lua.from_value(value)?; assert_eq!(E::Tuple(1, 2), got); let value = lua.load(r#"{Struct = {a = 3}}"#).eval()?; - let got = lua.from_value(value)?; + let got: E = lua.from_value(value)?; assert_eq!(E::Struct { a: 3 }, got); + let value = lua.load(r#"{Wrap = null}"#).eval()?; + let got = lua.from_value(value)?; + assert_eq!(E::Wrap(UnitStruct), got); + + let value = lua.load(r#"{Wrap = null}"#).eval()?; + let got = lua.from_value(value)?; + assert_eq!(E::Wrap(()), got); + Ok(()) } #[test] -fn test_from_value_enum_untagged() -> Result<(), Box> { +fn test_from_value_enum_untagged() -> Result<(), Box> { let lua = Lua::new(); lua.globals().set("null", lua.null())?; @@ -447,7 +611,7 @@ fn test_from_value_enum_untagged() -> Result<(), Box> { } #[test] -fn test_from_value_with_options() -> Result<(), Box> { +fn test_from_value_with_options() -> Result<(), Box> { let lua = Lua::new(); // Deny unsupported types by default @@ -485,7 +649,7 @@ fn test_from_value_with_options() -> Result<(), Box> { // Check recursion when using `Serialize` impl let t = lua.create_table()?; - t.set("t", t.clone())?; + t.set("t", &t)?; assert!(serde_json::to_string(&t).is_err()); // Serialize Lua globals table @@ -502,3 +666,176 @@ fn test_from_value_with_options() -> Result<(), Box> { Ok(()) } + +#[test] +fn test_from_value_userdata() -> Result<(), Box> { + let lua = Lua::new(); + + // Tuple struct + #[derive(Serialize, Deserialize)] + struct MyUserData(i64, String); + + impl UserData for MyUserData {} + + let ud = lua.create_ser_userdata(MyUserData(123, "test userdata".into()))?; + + match lua.from_value::(Value::UserData(ud)) { + Ok(_) => {} + Err(err) => panic!("expected no errors, got {err:?}"), + }; + + // Newtype struct + #[derive(Serialize, Deserialize)] + struct NewtypeUserdata(String); + + impl UserData for NewtypeUserdata {} + + let ud = lua.create_ser_userdata(NewtypeUserdata("newtype userdata".into()))?; + + match lua.from_value::(Value::UserData(ud)) { + Ok(_) => {} + Err(err) => panic!("expected no errors, got {err:?}"), + }; + + // Option + #[derive(Serialize, Deserialize)] + struct UnitUserdata; + + impl UserData for UnitUserdata {} + + let ud = lua.create_ser_userdata(UnitUserdata)?; + + match lua.from_value::>(Value::UserData(ud)) { + Ok(Some(_)) => {} + Ok(_) => panic!("expected `Some`, got `None`"), + Err(err) => panic!("expected no errors, got {err:?}"), + }; + + // Destructed userdata with skip option + let ud = lua.create_ser_userdata(NewtypeUserdata("newtype userdata".into()))?; + let _ = ud.take::()?; + + match lua.from_value_with::<()>( + Value::UserData(ud), + DeserializeOptions::new().deny_unsupported_types(false), + ) { + Ok(_) => {} + Err(err) => panic!("expected no errors, got {err:?}"), + }; + + Ok(()) +} + +#[test] +fn test_from_value_empty_table() -> Result<(), Box> { + let lua = Lua::new(); + + // By default we encode empty tables as objects + let t = lua.create_table()?; + let got = lua.from_value::(Value::Table(t.clone()))?; + assert_eq!(got, serde_json::json!({})); + + // Set the option to encode empty tables as array + let got = lua + .from_value_with::( + Value::Table(t.clone()), + DeserializeOptions::new().encode_empty_tables_as_array(true), + ) + .unwrap(); + assert_eq!(got, serde_json::json!([])); + + // Check hashmap table with this option + t.raw_set("hello", "world")?; + let got = lua + .from_value_with::( + Value::Table(t), + DeserializeOptions::new().encode_empty_tables_as_array(true), + ) + .unwrap(); + assert_eq!(got, serde_json::json!({"hello": "world"})); + + Ok(()) +} + +#[test] +fn test_from_value_sorted() -> Result<(), Box> { + let lua = Lua::new(); + + let to_json = lua.create_function(|lua, value| { + let json_value: serde_json::Value = + lua.from_value_with(value, DeserializeOptions::new().sort_keys(true))?; + serde_json::to_string(&json_value).into_lua_err() + })?; + lua.globals().set("to_json", to_json)?; + + lua.load( + r#" + local json = to_json({c = 3, b = 2, hello = "world", x = {1}, ["0a"] = {z = "z", d = "d"}}) + assert(json == '{"0a":{"d":"d","z":"z"},"b":2,"c":3,"hello":"world","x":[1]}', "invalid json") + "#, + ) + .exec() + .unwrap(); + + Ok(()) +} + +#[test] +fn test_arbitrary_precision() { + let lua = Lua::new(); + + let opts = SerializeOptions::new().detect_serde_json_arbitrary_precision(true); + + // Number + let num = serde_json::Value::Number(serde_json::Number::from_f64(1.244e2).unwrap()); + let num = lua.to_value_with(&num, opts).unwrap(); + assert_eq!(num, Value::Number(1.244e2)); + + // Integer + let num = serde_json::Value::Number(serde_json::Number::from_f64(123.0).unwrap()); + let num = lua.to_value_with(&num, opts).unwrap(); + assert_eq!(num, Value::Integer(123)); + + // Max u64 + let num = serde_json::Value::Number(serde_json::Number::from(i64::MAX)); + let num = lua.to_value_with(&num, opts).unwrap(); + assert_eq!(num, Value::Number(i64::MAX as f64)); + + // Check that the option is disabled by default + let num = serde_json::Value::Number(serde_json::Number::from_f64(1.244e2).unwrap()); + let num = lua.to_value(&num).unwrap(); + assert_eq!(num.type_name(), "table"); + assert_eq!( + format!("{:#?}", num), + "{\n [\"$serde_json::private::Number\"] = \"124.4\",\n}" + ); +} + +#[cfg(feature = "luau")] +#[test] +fn test_buffer_serialize() -> LuaResult<()> { + let lua = Lua::new(); + + let buf = lua.create_buffer(&[1, 2, 3, 4])?; + let val = serde_value::to_value(&buf).unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4])); + + // Try empty buffer + let buf = lua.create_buffer(&[])?; + let val = serde_value::to_value(&buf).unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![])); + + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_buffer_from_value() -> LuaResult<()> { + let lua = Lua::new(); + + let buf = lua.create_buffer(&[1, 2, 3, 4])?; + let val = lua.from_value::(Value::Buffer(buf)).unwrap(); + assert_eq!(val, serde_value::Value::Bytes(vec![1, 2, 3, 4])); + + Ok(()) +} diff --git a/tests/static.rs b/tests/static.rs deleted file mode 100644 index 5bc286bf..00000000 --- a/tests/static.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::cell::RefCell; - -use mlua::{Lua, Result, Table}; - -#[test] -fn test_static_lua() -> Result<()> { - let lua = Lua::new().into_static(); - - thread_local! { - static TABLE: RefCell>> = RefCell::new(None); - } - - let f = lua.create_function(|_, table: Table| { - TABLE.with(|t| { - table.raw_insert(1, "hello")?; - *t.borrow_mut() = Some(table); - Ok(()) - }) - })?; - - f.call(lua.create_table()?)?; - drop(f); - lua.gc_collect()?; - - TABLE.with(|t| { - assert!(t.borrow().as_ref().unwrap().len().unwrap() == 1); - *t.borrow_mut() = None; - }); - - // Consume the Lua instance - unsafe { Lua::from_static(lua) }; - - Ok(()) -} - -#[test] -fn test_static_lua_coroutine() -> Result<()> { - let lua = Lua::new().into_static(); - - thread_local! { - static TABLE: RefCell>> = RefCell::new(None); - } - - let f = lua.create_function(|_, table: Table| { - TABLE.with(|t| { - table.raw_insert(1, "hello")?; - *t.borrow_mut() = Some(table); - Ok(()) - }) - })?; - - let co = lua.create_thread(f)?; - co.resume::<_, ()>(lua.create_table()?)?; - drop(co); - lua.gc_collect()?; - - TABLE.with(|t| { - assert_eq!( - t.borrow().as_ref().unwrap().get::<_, String>(1i32).unwrap(), - "hello".to_string() - ); - *t.borrow_mut() = None; - }); - - // Consume the Lua instance - unsafe { Lua::from_static(lua) }; - - Ok(()) -} - -#[cfg(feature = "async")] -#[tokio::test] -async fn test_static_async() -> Result<()> { - let lua = Lua::new().into_static(); - - let timer = - lua.create_async_function(|_, (i, n, f): (u64, u64, mlua::Function)| async move { - tokio::task::spawn_local(async move { - let dur = std::time::Duration::from_millis(i); - for _ in 0..n { - tokio::task::spawn_local(f.call_async::<(), ()>(())); - tokio::time::sleep(dur).await; - } - }); - Ok(()) - })?; - lua.globals().set("timer", timer)?; - - { - let local_set = tokio::task::LocalSet::new(); - local_set - .run_until( - lua.load( - r#" - local cnt = 0 - timer(1, 100, function() - cnt = cnt + 1 - if cnt % 10 == 0 then - collectgarbage() - end - end) - "#, - ) - .exec_async(), - ) - .await?; - local_set.await; - } - - // Consume the Lua instance - unsafe { Lua::from_static(lua) }; - - Ok(()) -} diff --git a/tests/string.rs b/tests/string.rs index 35a9a71d..6802d906 100644 --- a/tests/string.rs +++ b/tests/string.rs @@ -1,24 +1,41 @@ use std::borrow::Cow; use std::collections::HashSet; -use mlua::{Lua, Result, String}; +use mlua::{Lua, LuaString, Result}; #[test] fn test_string_compare() { - fn with_str(s: &str, f: F) { - f(Lua::new().create_string(s).unwrap()); + let lua = Lua::new(); + + fn with_str(lua: &Lua, s: &str, f: F) { + f(lua.create_string(s).unwrap()); } // Tests that all comparisons we want to have are usable - with_str("teststring", |t| assert_eq!(t, "teststring")); // &str - with_str("teststring", |t| assert_eq!(t, b"teststring")); // &[u8] - with_str("teststring", |t| assert_eq!(t, b"teststring".to_vec())); // Vec - with_str("teststring", |t| assert_eq!(t, "teststring".to_string())); // String - with_str("teststring", |t| assert_eq!(t, t)); // mlua::String - with_str("teststring", |t| { - assert_eq!(t, Cow::from(b"teststring".as_ref())) - }); // Cow (borrowed) - with_str("bla", |t| assert_eq!(t, Cow::from(b"bla".to_vec()))); // Cow (owned) + with_str(&lua, "teststring", |t| assert_eq!(t, "teststring")); // &str + with_str(&lua, "teststring", |t| assert_eq!(t, b"teststring")); // &[u8] + with_str(&lua, "teststring", |t| assert_eq!(t, b"teststring".to_vec())); // Vec + with_str(&lua, "teststring", |t| assert_eq!(t, "teststring".to_string())); // String + with_str(&lua, "teststring", |t| assert_eq!(t, t)); // mlua::String + with_str(&lua, "teststring", |t| { + assert_eq!(t, Cow::from(b"teststring".as_ref())) // Cow (borrowed) + }); + with_str(&lua, "bla", |t| assert_eq!(t, Cow::from(b"bla".to_vec()))); // Cow (owned) + + // Test ordering + with_str(&lua, "a", |a| { + assert!(!(a < a)); + assert!(!(a > a)); + }); + with_str(&lua, "a", |a| assert!(a < "b")); + with_str(&lua, "a", |a| assert!(a < b"b")); + with_str(&lua, "a", |a| with_str(&lua, "b", |b| assert!(a < b))); + + // Long strings (not interned by Lua) + let long_str = "abc".repeat(100); + with_str(&lua, &long_str, |s1| { + with_str(&lua, &long_str, |s2| assert_eq!(s1, s2)) + }); } #[test] @@ -35,19 +52,13 @@ fn test_string_views() -> Result<()> { .exec()?; let globals = lua.globals(); - let ok: String = globals.get("ok")?; - let err: String = globals.get("err")?; - let empty: String = globals.get("empty")?; + let ok: LuaString = globals.get("ok")?; + let err: LuaString = globals.get("err")?; + let empty: LuaString = globals.get("empty")?; assert_eq!(ok.to_str()?, "null bytes are valid utf-8, wh\0 knew?"); - assert_eq!( - ok.to_string_lossy(), - "null bytes are valid utf-8, wh\0 knew?" - ); - assert_eq!( - ok.as_bytes(), - &b"null bytes are valid utf-8, wh\0 knew?"[..] - ); + assert_eq!(ok.to_string_lossy(), "null bytes are valid utf-8, wh\0 knew?"); + assert_eq!(ok.as_bytes(), &b"null bytes are valid utf-8, wh\0 knew?"[..]); assert!(err.to_str().is_err()); assert_eq!(err.as_bytes(), &b"but \xff isn't :("[..]); @@ -60,7 +71,7 @@ fn test_string_views() -> Result<()> { } #[test] -fn test_raw_string() -> Result<()> { +fn test_string_from_bytes() -> Result<()> { let lua = Lua::new(); let rs = lua.create_string(&[0, 1, 2, 3, 0, 1, 2, 3])?; @@ -73,13 +84,101 @@ fn test_raw_string() -> Result<()> { fn test_string_hash() -> Result<()> { let lua = Lua::new(); - let set: HashSet = lua.load(r#"{"hello", "world", "abc", 321}"#).eval()?; + let set: HashSet = lua.load(r#"{"hello", "world", "abc", 321}"#).eval()?; assert_eq!(set.len(), 4); - assert!(set.contains(b"hello".as_ref())); - assert!(set.contains(b"world".as_ref())); - assert!(set.contains(b"abc".as_ref())); - assert!(set.contains(b"321".as_ref())); - assert!(!set.contains(b"Hello".as_ref())); + assert!(set.contains(&lua.create_string("hello")?)); + assert!(set.contains(&lua.create_string("world")?)); + assert!(set.contains(&lua.create_string("abc")?)); + assert!(set.contains(&lua.create_string("321")?)); + assert!(!set.contains(&lua.create_string("Hello")?)); + + Ok(()) +} + +#[test] +fn test_string_fmt_debug() -> Result<()> { + let lua = Lua::new(); + + // Valid utf8 + let s = lua.create_string("hello")?; + assert_eq!(format!("{s:?}"), r#""hello""#); + assert_eq!(format!("{:?}", s.to_str()?), r#""hello""#); + assert_eq!(format!("{:?}", s.as_bytes()), "[104, 101, 108, 108, 111]"); + + // Invalid utf8 + let s = lua.create_string(b"hello\0world\r\n\t\xf0\x90\x80")?; + assert_eq!(format!("{s:?}"), r#"b"hello\0world\r\n\t\xf0\x90\x80""#); + + Ok(()) +} + +#[test] +fn test_string_pointer() -> Result<()> { + let lua = Lua::new(); + + let str1 = lua.create_string("hello")?; + let str2 = lua.create_string("hello")?; + + // Lua uses string interning, so these should be the same + assert_eq!(str1.to_pointer(), str2.to_pointer()); + + Ok(()) +} + +#[test] +fn test_string_display() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_string("hello")?; + assert_eq!(format!("{}", s.display()), "hello"); + + // With invalid utf8 + let s = lua.create_string(b"hello\0world\xFF")?; + assert_eq!(format!("{}", s.display()), "hello\0world�"); + + Ok(()) +} + +#[test] +fn test_string_wrap() -> Result<()> { + let lua = Lua::new(); + + let s = LuaString::wrap("hello, world"); + lua.globals().set("s", s)?; + assert_eq!(lua.globals().get::("s")?, "hello, world"); + + let s2 = LuaString::wrap("hello, world (owned)".to_string()); + lua.globals().set("s2", s2)?; + assert_eq!(lua.globals().get::("s2")?, "hello, world (owned)"); + + Ok(()) +} + +#[test] +fn test_bytes_into_iter() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_string("hello")?; + let bytes = s.as_bytes(); + + for (i, &b) in bytes.into_iter().enumerate() { + assert_eq!(b, s.as_bytes()[i]); + } + + Ok(()) +} + +#[cfg(feature = "lua55")] +#[test] +fn test_external_string() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_external_string(b"abc\0")?; + assert_eq!( + s.as_bytes(), + b"abc\0", + "Trailing null byte should be preserved if present explicitly" + ); Ok(()) } diff --git a/tests/table.rs b/tests/table.rs index 3ee0797f..e0bc5f44 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -1,14 +1,16 @@ -use mlua::{Error, Lua, Nil, Result, Table, TableExt, Value}; +use mlua::{Error, Lua, ObjectLike, Result, Table, Value}; #[test] -fn test_set_get() -> Result<()> { +fn test_globals_set_get() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); globals.set("foo", "bar")?; globals.set("baz", "baf")?; - assert_eq!(globals.get::<_, String>("foo")?, "bar"); - assert_eq!(globals.get::<_, String>("baz")?, "baf"); + assert_eq!(globals.get::("foo")?, "bar"); + assert_eq!(globals.get::("baz")?, "baf"); + + lua.load(r#"assert(foo == "bar")"#).exec().unwrap(); Ok(()) } @@ -19,16 +21,6 @@ fn test_table() -> Result<()> { let globals = lua.globals(); - globals.set("table", lua.create_table()?)?; - let table1: Table = globals.get("table")?; - let table2: Table = globals.get("table")?; - - table1.set("foo", "bar")?; - table2.set("baz", "baf")?; - - assert_eq!(table2.get::<_, String>("foo")?, "bar"); - assert_eq!(table1.get::<_, String>("baz")?, "baf"); - lua.load( r#" table1 = {1, 2, 3, 4, 5} @@ -38,74 +30,165 @@ fn test_table() -> Result<()> { ) .exec()?; - let table1 = globals.get::<_, Table>("table1")?; - let table2 = globals.get::<_, Table>("table2")?; - let table3 = globals.get::<_, Table>("table3")?; - + let table1 = globals.get::
("table1")?; assert_eq!(table1.len()?, 5); + assert!(!table1.is_empty()); assert_eq!( - table1 - .clone() - .pairs() - .collect::>>()?, + table1.pairs().collect::>>()?, vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] ); assert_eq!( - table1 - .clone() - .sequence_values() - .collect::>>()?, + table1.sequence_values().collect::>>()?, vec![1, 2, 3, 4, 5] ); + assert_eq!(table1, [1, 2, 3, 4, 5]); + assert_eq!(table1, [1, 2, 3, 4, 5].as_slice()); + let table2 = globals.get::
("table2")?; assert_eq!(table2.len()?, 0); - assert_eq!( - table2 - .clone() - .pairs() - .collect::>>()?, - vec![] - ); - assert_eq!( - table2.sequence_values().collect::>>()?, - vec![] - ); + assert!(table2.is_empty()); + assert_eq!(table2.pairs().collect::>>()?, vec![]); + assert_eq!(table2, [0; 0]); + let table3 = globals.get::
("table3")?; // sequence_values should only iterate until the first border + assert_eq!(table3, [1, 2]); assert_eq!( table3.sequence_values().collect::>>()?, vec![1, 2] ); - globals.set("table4", lua.create_sequence_from(vec![1, 2, 3, 4, 5])?)?; - let table4 = globals.get::<_, Table>("table4")?; + Ok(()) +} + +#[test] +#[cfg(target_os = "linux")] // Linux allow overcommiting the memory (relevant for CI) +fn test_table_with_large_capacity() { + let lua = Lua::new(); + + let t = lua.create_table_with_capacity(1 << 26, 1 << 26); + assert!(t.is_ok()); +} + +#[test] +fn test_table_push_pop() -> Result<()> { + let lua = Lua::new(); + + // Test raw access + let table1 = lua.create_sequence_from([123])?; + table1.raw_push(321)?; + assert_eq!(table1, [123, 321]); + assert_eq!(table1.raw_pop::()?, 321); + assert_eq!(table1.raw_pop::()?, 123); + assert_eq!(table1.raw_pop::()?, Value::Nil); // An extra pop should do nothing + assert_eq!(table1.raw_len(), 0); + assert_eq!(table1, [0; 0]); + + // Test access through metamethods + let table2 = lua + .load( + r#" + local proxy_table = {234} + table2 = setmetatable({}, { + __len = function() return #proxy_table end, + __index = proxy_table, + __newindex = proxy_table, + }) + return table2 + "#, + ) + .eval::
()?; + table2.push(345)?; + assert_eq!(table2.len()?, 2); assert_eq!( - table4 - .clone() - .pairs() - .collect::>>()?, - vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] + table2.sequence_values::().collect::>>()?, + vec![] ); + assert_eq!(table2.pop::()?, 345); + assert_eq!(table2.pop::()?, 234); + assert_eq!(table2.pop::()?, Value::Nil); + assert_eq!(table2.len()?, 0); + + Ok(()) +} +#[test] +fn test_table_insert_remove() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + + globals.set("table4", [1, 2, 3, 4, 5])?; + let table4 = globals.get::
("table4")?; + assert_eq!( + table4.pairs().collect::>>()?, + vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] + ); table4.raw_insert(4, 35)?; table4.raw_insert(7, 7)?; assert_eq!( - table4 - .clone() - .pairs() - .collect::>>()?, + table4.pairs().collect::>>()?, vec![(1, 1), (2, 2), (3, 3), (4, 35), (5, 4), (6, 5), (7, 7)] ); - table4.raw_remove(1)?; assert_eq!( - table4 - .clone() - .pairs() - .collect::>>()?, + table4.pairs().collect::>>()?, vec![(1, 2), (2, 3), (3, 35), (4, 4), (5, 5), (6, 7)] ); + // Wrong index, tables are 1-indexed + assert!(table4.raw_insert(0, "123").is_err()); + + Ok(()) +} + +#[test] +fn test_table_clear() -> Result<()> { + let lua = Lua::new(); + + let t = lua.create_table()?; + + // Check readonly error + #[cfg(feature = "luau")] + { + t.set_readonly(true); + assert!(matches!( + t.clear(), + Err(Error::RuntimeError(err)) if err.contains("attempt to modify a readonly table") + )); + t.set_readonly(false); + } + + // Set array and hash parts + t.push("abc")?; + t.push("bcd")?; + t.set("a", "1")?; + t.set("b", "2")?; + t.clear()?; + assert_eq!(t.len()?, 0); + assert_eq!(t.pairs::().count(), 0); + + // Test table with metamethods + let t2 = lua + .load( + r#" + setmetatable({1, 2, 3, a = "1"}, { + __index = function() error("index error") end, + __newindex = function() error("newindex error") end, + __len = function() error("len error") end, + __pairs = function() error("pairs error") end, + }) + "#, + ) + .eval::
()?; + assert_eq!(t2.raw_len(), 3); + assert!(!t2.is_empty()); + t2.clear()?; + assert_eq!(t2.raw_len(), 0); + assert!(t2.is_empty()); + assert_eq!(t2.raw_get::("a")?, Value::Nil); + assert_ne!(t2.metatable(), None); + Ok(()) } @@ -115,29 +198,92 @@ fn test_table_sequence_from() -> Result<()> { let get_table = lua.create_function(|_, t: Table| Ok(t))?; - assert_eq!( - get_table - .call::<_, Table>(vec![1, 2, 3])? - .sequence_values() - .collect::>>()?, - vec![1, 2, 3] - ); + assert_eq!(get_table.call::
(vec![1, 2, 3])?, [1, 2, 3]); + assert_eq!(get_table.call::
([4, 5, 6])?, [4, 5, 6]); + assert_eq!(get_table.call::
([7, 8, 9].as_slice())?, [7, 8, 9]); - assert_eq!( - get_table - .call::<_, Table>([1, 2, 3].as_ref())? - .sequence_values() - .collect::>>()?, - vec![1, 2, 3] - ); + Ok(()) +} - assert_eq!( - get_table - .call::<_, Table>([1, 2, 3])? - .sequence_values() - .collect::>>()?, - vec![1, 2, 3] - ); +#[test] +fn test_table_pairs() -> Result<()> { + let lua = Lua::new(); + + let table = lua + .load( + r#" + { + foo = "bar", + baz = "baf", + [123] = 456, + [789] = 101112, + 5, + } + "#, + ) + .eval::
()?; + + for (i, kv) in table.pairs::().enumerate() { + let (k, _v) = kv.unwrap(); + match i { + // Try to add a new key + 0 => table.set("new_key", "new_value")?, + // Try to delete the 2nd key + 1 => { + table.set(k, Value::Nil)?; + lua.gc_collect()?; + } + _ => {} + } + } + + Ok(()) +} + +#[test] +fn test_table_for_each() -> Result<()> { + let lua = Lua::new(); + + let table = lua + .load( + r#" + { + foo = "bar", + baz = "baf", + [123] = 456, + [789] = 101112, + 5, + } + "#, + ) + .eval::
()?; + + let mut i = 0; + table.for_each::(|k, _| { + if i == 0 { + // Delete first key + table.set(k, Value::Nil)?; + lua.gc_collect()?; + } + Ok(i += 1) + })?; + assert_eq!(i, 5); + + Ok(()) +} + +#[test] +fn test_table_for_each_value() -> Result<()> { + let lua = Lua::new(); + + let table = lua.load("{1, 2, 3, 4, 5, nil, 7}").eval::
()?; + let mut sum = 0; + table.for_each_value::(|v| { + sum += v; + Ok(()) + })?; + // Iterations stops at the first nil + assert_eq!(sum, 1 + 2 + 3 + 4 + 5); Ok(()) } @@ -159,13 +305,13 @@ fn test_table_scope() -> Result<()> { // Make sure that table gets do not borrow the table, but instead just borrow lua. let tin; { - let touter = globals.get::<_, Table>("touter")?; - tin = touter.get::<_, Table>("tin")?; + let touter = globals.get::
("touter")?; + tin = touter.get::
("tin")?; } - assert_eq!(tin.get::<_, i64>(1)?, 1); - assert_eq!(tin.get::<_, i64>(2)?, 2); - assert_eq!(tin.get::<_, i64>(3)?, 3); + assert_eq!(tin.get::(1)?, 1); + assert_eq!(tin.get::(2)?, 2); + assert_eq!(tin.get::(3)?, 3); Ok(()) } @@ -177,23 +323,17 @@ fn test_metatable() -> Result<()> { let table = lua.create_table()?; let metatable = lua.create_table()?; metatable.set("__index", lua.create_function(|_, ()| Ok("index_value"))?)?; - table.set_metatable(Some(metatable)); - assert_eq!(table.get::<_, String>("any_key")?, "index_value"); - match table.raw_get::<_, Value>("any_key")? { - Nil => {} - _ => panic!(), - } - table.set_metatable(None); - match table.get::<_, Value>("any_key")? { - Nil => {} - _ => panic!(), - }; + table.set_metatable(Some(metatable))?; + assert_eq!(table.get::("any_key")?, "index_value"); + assert_eq!(table.raw_get::("any_key")?, Value::Nil); + table.set_metatable(None)?; + assert_eq!(table.get::("any_key")?, Value::Nil); Ok(()) } #[test] -fn test_table_eq() -> Result<()> { +fn test_table_equals() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); @@ -211,10 +351,10 @@ fn test_table_eq() -> Result<()> { ) .exec()?; - let table1 = globals.get::<_, Table>("table1")?; - let table2 = globals.get::<_, Table>("table2")?; - let table3 = globals.get::<_, Table>("table3")?; - let table4 = globals.get::<_, Table>("table4")?; + let table1 = globals.get::
("table1")?; + let table2 = globals.get::
("table2")?; + let table3 = globals.get::
("table3")?; + let table4 = globals.get::
("table4")?; assert!(table1 != table2); assert!(!table1.equals(&table2)?); @@ -226,6 +366,20 @@ fn test_table_eq() -> Result<()> { Ok(()) } +#[test] +fn test_table_pointer() -> Result<()> { + let lua = Lua::new(); + + let table1 = lua.create_table()?; + let table2 = lua.create_table()?; + + // Clone should not create a new table + assert_eq!(table1.to_pointer(), table1.clone().to_pointer()); + assert_ne!(table1.to_pointer(), table2.to_pointer()); + + Ok(()) +} + #[test] fn test_table_error() -> Result<()> { let lua = Lua::new(); @@ -251,17 +405,51 @@ fn test_table_error() -> Result<()> { let bad_table: Table = globals.get("table")?; assert!(bad_table.set(1, 1).is_err()); - assert!(bad_table.get::<_, i32>(1).is_err()); + assert!(bad_table.get::(1).is_err()); assert!(bad_table.len().is_err()); assert!(bad_table.raw_set(1, 1).is_ok()); - assert!(bad_table.raw_get::<_, i32>(1).is_ok()); + assert!(bad_table.raw_get::(1).is_ok()); assert_eq!(bad_table.raw_len(), 1); Ok(()) } #[test] -fn test_table_call() -> Result<()> { +fn test_table_fmt() -> Result<()> { + let lua = Lua::new(); + + let table = lua + .load( + r#" + local t = {1, 2, 3, a = 5, b = { 6 }} + t["special-"] = 10 + t[9.2] = 9.2 + t[1.99] = 1.99 + t[true] = true + t[false] = false + return t + "#, + ) + .eval::
()?; + assert!(format!("{table:?}").starts_with("Table(Ref(")); + + // Pretty print + assert_eq!( + format!("{table:#?}"), + "{\n [false] = false,\n [true] = true,\n [1] = 1,\n [1.99] = 1.99,\n [2] = 2,\n [3] = 3,\n [9.2] = 9.2,\n a = 5,\n b = {\n 6,\n },\n [\"special-\"] = 10,\n}" + ); + + let table2 = lua.create_table_from([("1", "first"), ("2", "second")])?; + assert_eq!( + format!("{table2:#?}"), + "{\n [\"1\"] = \"first\",\n [\"2\"] = \"second\",\n}" + ); + + Ok(()) +} + +#[test] +fn test_table_object_like() -> Result<()> { let lua = Lua::new(); lua.load( @@ -270,6 +458,10 @@ fn test_table_call() -> Result<()> { setmetatable(table, { __call = function(t, key) return "call_"..t[key] + end, + + __tostring = function() + return "table object" end }) @@ -286,19 +478,104 @@ fn test_table_call() -> Result<()> { let table: Table = lua.globals().get("table")?; - assert_eq!(table.call::<_, String>("b")?, "call_2"); - assert_eq!(table.call_function::<_, _, String>("func", "a")?, "func_a"); - assert_eq!( - table.call_method::<_, _, String>("method", "a")?, - "method_1" - ); +
::set(&table, "c", 3)?; + assert_eq!(
::get::(&table, "c")?, 3); + assert_eq!(table.call::("b")?, "call_2"); + assert_eq!(table.call_function::("func", "a")?, "func_a"); + assert_eq!(table.call_method::("method", "a")?, "method_1"); + assert_eq!(table.to_string()?, "table object"); + + match table.call_method::<()>("non_existent", ()) { + Err(Error::RuntimeError(err)) => { + assert!(err.contains("attempt to call a nil value (function 'non_existent')")) + } + r => panic!("expected RuntimeError, got {r:?}"), + } // Test calling non-callable table let table2 = lua.create_table()?; - assert!(matches!( - table2.call::<_, ()>(()), - Err(Error::RuntimeError(_)) - )); + assert!(matches!(table2.call::<()>(()), Err(Error::RuntimeError(_)))); + + Ok(()) +} + +#[test] +fn test_table_get_path() -> Result<()> { + let lua = Lua::new(); + + // Create a nested table structure + let table = lua + .load( + r#" + { + a = { + b = { + c = "hello", + d = 42 + }, + [1] = "first", + ["special key"] = "special value" + }, + abc = "top level", + x = {}, + ["🚀"] = "rocket", + [1] = { + ["nested-key"] = { + [42] = { + final = "hello!", + }, + }, + ["key\"with\"quotes"] = "value1", + ["key'with'quotes"] = "value2", + ["key\\with\\backslashes"] = "value3", + [-2] = "negative index", + }, + } + "#, + ) + .eval::
()?; + + // Test basic dot notation + assert_eq!(table.get_path::(".a.b.c")?, "hello"); + assert_eq!(table.get_path::("a.b.c")?, "hello"); + assert_eq!(table.get_path::("a.b.d")?, 42); + assert_eq!(table.get_path::("abc")?, "top level"); + + // Test bracket notation with integer keys + assert_eq!(table.get_path::("a[1]")?, "first"); + assert_eq!(table.get_path::("[1][-2]")?, "negative index"); + + // Test bracket notation with string keys + assert_eq!(table.get_path::("a[\"special key\"]")?, "special value"); + assert_eq!(table.get_path::("a['special key']")?, "special value"); + assert_eq!(table.get_path::(r#"[1]["key\"with\"quotes"]"#)?, "value1"); + assert_eq!(table.get_path::(r#"[1]['key"with"quotes']"#)?, "value1"); + assert_eq!(table.get_path::(r#"[1]['key\'with\'quotes']"#)?, "value2"); + assert_eq!( + table.get_path::(r#"[1]["key\\with\\backslashes"]"#)?, + "value3" + ); + + // Test mixed notation + assert_eq!(table.get_path::("[1].nested-key[42].final")?, "hello!"); + + // Test unicode keys + assert_eq!(table.get_path::("🚀")?, "rocket"); + + // Test empty path returns the table itself + assert_eq!(table.get_path::
("")?, table); + + // Test safe navigation + assert_eq!(table.get_path::("a?.b.c")?, "hello"); + assert_eq!(table.get_path::("x.y?.z")?, Value::Nil); + assert_eq!(table.get_path::("[1].nested-key[43]?.final")?, Value::Nil); + + // Test path with whitespace + assert_eq!(table.get_path::(" .a [\"b\"] .c ")?, "hello"); + + // Test indexing non-indexable value + let err = table.get_path::("abc.c").unwrap_err().to_string(); + assert_eq!(err, "runtime error: attempt to index a string value with key 'c'"); Ok(()) } diff --git a/tests/tests.rs b/tests/tests.rs index 8dc80ad7..99716f63 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,22 +1,39 @@ use std::collections::HashMap; +#[cfg(not(target_arch = "wasm32"))] use std::iter::FromIterator; -use std::panic::{catch_unwind, AssertUnwindSafe}; -use std::string::String as StdString; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::panic::{AssertUnwindSafe, catch_unwind}; use std::sync::Arc; use std::{error, f32, f64, fmt}; use mlua::{ - ChunkMode, Error, ExternalError, Function, Lua, LuaOptions, Nil, Result, StdLib, String, Table, - UserData, Value, Variadic, + ChunkMode, Error, ExternalError, Function, Lua, LuaOptions, Nil, Result, StdLib, Table, UserData, Value, + Variadic, ffi, }; +#[test] +fn test_weak_lua() { + let lua = Lua::new(); + let weak_lua = lua.weak(); + assert!(weak_lua.try_upgrade().is_some()); + drop(lua); + assert!(weak_lua.try_upgrade().is_none()); +} + +#[test] +#[should_panic(expected = "Lua instance is destroyed")] +fn test_weak_lua_panic() { + let lua = Lua::new(); + let weak_lua = lua.weak(); + drop(lua); + let _ = weak_lua.upgrade(); +} + #[cfg(not(feature = "luau"))] #[test] fn test_safety() -> Result<()> { let lua = Lua::new(); assert!(lua.load(r#"require "debug""#).exec().is_err()); - match lua.load_from_std_lib(StdLib::DEBUG) { + match lua.load_std_libs(StdLib::DEBUG) { Err(Error::SafetyError(_)) => {} Err(e) => panic!("expected SafetyError, got {:?}", e), Ok(_) => panic!("expected SafetyError, got no error"), @@ -51,8 +68,8 @@ fn test_safety() -> Result<()> { // Test safety rules after dynamically loading `package` library let lua = Lua::new_with(StdLib::NONE, LuaOptions::default())?; - assert!(lua.globals().get::<_, Option>("require")?.is_none()); - lua.load_from_std_lib(StdLib::PACKAGE)?; + assert!(lua.globals().get::>("require")?.is_none()); + lua.load_std_libs(StdLib::PACKAGE)?; match lua.load(r#"package.loadlib()"#).exec() { Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { Error::SafetyError(_) => {} @@ -69,10 +86,11 @@ fn test_safety() -> Result<()> { fn test_load() -> Result<()> { let lua = Lua::new(); - let func = lua.load("return 1+2").into_function()?; + let func = lua.load("\treturn 1+2").into_function()?; let result: i32 = func.call(())?; assert_eq!(result, 3); + assert!(lua.load("").exec().is_ok()); assert!(lua.load("§$%§&$%&").exec().is_err()); Ok(()) @@ -89,7 +107,7 @@ fn test_exec() -> Result<()> { "#, ) .exec()?; - assert_eq!(globals.get::<_, String>("res")?, "foobar"); + assert_eq!(globals.get::("res")?, "foobar"); let module: Table = lua .load( @@ -104,12 +122,8 @@ fn test_exec() -> Result<()> { "#, ) .eval()?; - println!("checkpoint"); assert!(module.contains_key("func")?); - assert_eq!( - module.get::<_, Function>("func")?.call::<_, String>(())?, - "hello" - ); + assert_eq!(module.get::("func")?.call::(())?, "hello"); Ok(()) } @@ -126,10 +140,32 @@ fn test_eval() -> Result<()> { incomplete_input: true, .. }) => {} - r => panic!( - "expected SyntaxError with incomplete_input=true, got {:?}", - r - ), + r => panic!("expected SyntaxError with incomplete_input=true, got {:?}", r), + } + + Ok(()) +} + +#[test] +fn test_replace_globals() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.create_table()?; + globals.set("foo", "bar")?; + + lua.set_globals(globals.clone())?; + let val = lua.load("return foo").eval::()?; + assert_eq!(val, "bar"); + + // Updating globals in sandboxed Lua state is not allowed + #[cfg(feature = "luau")] + { + lua.sandbox(true)?; + match lua.set_globals(globals) { + Err(Error::RuntimeError(msg)) + if msg.contains("cannot change globals in a sandboxed Lua state") => {} + r => panic!("expected RuntimeError(...) with a specific error message, got {r:?}"), + } } Ok(()) @@ -139,10 +175,7 @@ fn test_eval() -> Result<()> { fn test_load_mode() -> Result<()> { let lua = unsafe { Lua::unsafe_new() }; - assert_eq!( - lua.load("1 + 1").set_mode(ChunkMode::Text).eval::()?, - 2 - ); + assert_eq!(lua.load("1 + 1").set_mode(ChunkMode::Text).eval::()?, 2); match lua.load("1 + 1").set_mode(ChunkMode::Binary).exec() { Ok(_) => panic!("expected SyntaxError, got no error"), Err(Error::SyntaxError { message: msg, .. }) => { @@ -154,14 +187,9 @@ fn test_load_mode() -> Result<()> { #[cfg(not(feature = "luau"))] let bytecode = lua.load("return 1 + 1").into_function()?.dump(true); #[cfg(feature = "luau")] - let bytecode = mlua::Compiler::new().compile("return 1 + 1"); + let bytecode = mlua::Compiler::new().compile("return 1 + 1")?; assert_eq!(lua.load(&bytecode).eval::()?, 2); - assert_eq!( - lua.load(&bytecode) - .set_mode(ChunkMode::Binary) - .eval::()?, - 2 - ); + assert_eq!(lua.load(&bytecode).set_mode(ChunkMode::Binary).eval::()?, 2); match lua.load(&bytecode).set_mode(ChunkMode::Text).exec() { Ok(_) => panic!("expected SyntaxError, got no error"), Err(Error::SyntaxError { message: msg, .. }) => { @@ -191,13 +219,13 @@ fn test_lua_multi() -> Result<()> { .exec()?; let globals = lua.globals(); - let concat = globals.get::<_, Function>("concat")?; - let mreturn = globals.get::<_, Function>("mreturn")?; + let concat = globals.get::("concat")?; + let mreturn = globals.get::("mreturn")?; - assert_eq!(concat.call::<_, String>(("foo", "bar"))?, "foobar"); - let (a, b) = mreturn.call::<_, (u64, u64)>(())?; + assert_eq!(concat.call::(("foo", "bar"))?, "foobar"); + let (a, b) = mreturn.call::<(u64, u64)>(())?; assert_eq!((a, b), (1, 2)); - let (a, b, v) = mreturn.call::<_, (u64, u64, Variadic)>(())?; + let (a, b, v) = mreturn.call::<(u64, u64, Variadic)>(())?; assert_eq!((a, b), (1, 2)); assert_eq!(v[..], [3, 4, 5, 6]); @@ -219,10 +247,10 @@ fn test_coercion() -> Result<()> { .exec()?; let globals = lua.globals(); - assert_eq!(globals.get::<_, String>("int")?, "123"); - assert_eq!(globals.get::<_, i32>("str")?, 123); - assert_eq!(globals.get::<_, i32>("num")?, 123); - assert!(globals.get::<_, String>("func").is_err()); + assert_eq!(globals.get::("int")?, "123"); + assert_eq!(globals.get::("str")?, 123); + assert_eq!(globals.get::("num")?, 123); + assert!(globals.get::("func").is_err()); Ok(()) } @@ -275,7 +303,10 @@ fn test_error() -> Result<()> { end, 3) local function handler(err) - if string.match(_VERSION, ' 5%.1$') or string.match(_VERSION, ' 5%.2$') or _VERSION == "Luau" then + if string.match(_VERSION, " 5%.1$") + or string.match(_VERSION, " 5%.2$") + or string.match(_VERSION, "Luau") + then -- Special case for Lua 5.1/5.2 and Luau local caps = string.match(err, ': (%d+)$') if caps then @@ -303,41 +334,36 @@ fn test_error() -> Result<()> { ) .exec()?; - let rust_error_function = - lua.create_function(|_, ()| -> Result<()> { Err(TestError.to_lua_err()) })?; + let rust_error_function = lua.create_function(|_, ()| -> Result<()> { Err(TestError.into_lua_err()) })?; globals.set("rust_error_function", rust_error_function)?; - let no_error = globals.get::<_, Function>("no_error")?; - let lua_error = globals.get::<_, Function>("lua_error")?; - let rust_error = globals.get::<_, Function>("rust_error")?; - let return_error = globals.get::<_, Function>("return_error")?; - let return_string_error = globals.get::<_, Function>("return_string_error")?; - let test_pcall = globals.get::<_, Function>("test_pcall")?; - let understand_recursion = globals.get::<_, Function>("understand_recursion")?; + let no_error = globals.get::("no_error")?; + assert!(no_error.call::<()>(()).is_ok()); - assert!(no_error.call::<_, ()>(()).is_ok()); - match lua_error.call::<_, ()>(()) { + let lua_error = globals.get::("lua_error")?; + match lua_error.call::<()>(()) { Err(Error::RuntimeError(_)) => {} Err(e) => panic!("error is not RuntimeError kind, got {:?}", e), _ => panic!("error not returned"), } - match rust_error.call::<_, ()>(()) { + + let rust_error = globals.get::("rust_error")?; + match rust_error.call::<()>(()) { Err(Error::CallbackError { .. }) => {} Err(e) => panic!("error is not CallbackError kind, got {:?}", e), _ => panic!("error not returned"), } - match return_error.call::<_, Value>(()) { + let return_error = globals.get::("return_error")?; + match return_error.call::(()) { Ok(Value::Error(_)) => {} _ => panic!("Value::Error not returned"), } - assert!(return_string_error.call::<_, Error>(()).is_ok()); + let return_string_error = globals.get::("return_string_error")?; + assert!(return_string_error.call::(()).is_ok()); - match lua - .load("if youre happy and you know it syntax error") - .exec() - { + match lua.load("if you are happy and you know it syntax error").exec() { Err(Error::SyntaxError { incomplete_input: false, .. @@ -354,26 +380,30 @@ fn test_error() -> Result<()> { _ => panic!("error not returned"), } - test_pcall.call::<_, ()>(())?; + let test_pcall = globals.get::("test_pcall")?; + test_pcall.call::<()>(())?; - assert!(understand_recursion.call::<_, ()>(()).is_err()); + #[cfg(not(target_arch = "wasm32"))] + { + let understand_recursion = globals.get::("understand_recursion")?; + assert!(understand_recursion.call::<()>(()).is_err()); + } Ok(()) } #[test] +#[cfg(not(panic = "abort"))] fn test_panic() -> Result<()> { fn make_lua(options: LuaOptions) -> Result { let lua = Lua::new_with(StdLib::ALL_SAFE, options)?; - let rust_panic_function = - lua.create_function(|_, msg: Option| -> Result<()> { - if let Some(msg) = msg { - panic!("{}", msg) - } - panic!("rust panic") - })?; - lua.globals() - .set("rust_panic_function", rust_panic_function)?; + let rust_panic_function = lua.create_function(|_, msg: Option| -> Result<()> { + if let Some(msg) = msg { + panic!("{}", msg) + } + panic!("rust panic") + })?; + lua.globals().set("rust_panic_function", rust_panic_function)?; Ok(lua) } @@ -407,7 +437,7 @@ fn test_panic() -> Result<()> { { let lua = make_lua(LuaOptions::default())?; match catch_unwind(AssertUnwindSafe(|| -> Result<()> { - let _catched_panic = lua + let _caught_panic = lua .load( r#" -- Set global @@ -422,7 +452,7 @@ fn test_panic() -> Result<()> { Err(_) => {} }; - assert!(lua.globals().get::<_, Value>("err")? == Value::Nil); + assert!(lua.globals().get::("err")? == Value::Nil); match lua.load("tostring(err)").exec() { Ok(_) => panic!("no error was detected"), Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { @@ -465,7 +495,7 @@ fn test_panic() -> Result<()> { .exec() }) { Ok(r) => panic!("no panic was detected: {:?}", r), - Err(p) => assert!(*p.downcast::().unwrap() == "rust panic from lua"), + Err(p) => assert!(*p.downcast::().unwrap() == "rust panic from lua"), } // Test disabling `catch_rust_panics` option / xpcall correctness @@ -489,39 +519,31 @@ fn test_panic() -> Result<()> { .exec() }) { Ok(r) => panic!("no panic was detected: {:?}", r), - Err(p) => assert!(*p.downcast::().unwrap() == "rust panic from lua"), + Err(p) => assert!(*p.downcast::().unwrap() == "rust panic from lua"), } Ok(()) } +#[cfg(target_pointer_width = "64")] #[test] -fn test_result_conversions() -> Result<()> { - let lua = Lua::new(); - let globals = lua.globals(); +fn test_safe_integers() -> Result<()> { + const MAX_SAFE_INTEGER: i64 = 2i64.pow(53) - 1; + const MIN_SAFE_INTEGER: i64 = -2i64.pow(53) + 1; - let err = lua.create_function(|_, ()| { - Ok(Err::( - "only through failure can we succeed".to_lua_err(), - )) - })?; - let ok = lua.create_function(|_, ()| Ok(Ok::<_, Error>("!".to_owned())))?; + let lua = Lua::new(); + let f = lua.load("return ...").into_function()?; - globals.set("err", err)?; - globals.set("ok", ok)?; + assert_eq!(f.call::(MAX_SAFE_INTEGER)?, MAX_SAFE_INTEGER); + assert_eq!(f.call::(MIN_SAFE_INTEGER)?, MIN_SAFE_INTEGER); - lua.load( - r#" - local r, e = err() - assert(r == nil) - assert(tostring(e):find("only through failure can we succeed") ~= nil) - - local r, e = ok() - assert(r == "!") - assert(e == nil) - "#, - ) - .exec()?; + // For Lua versions that does not support 64-bit integers, the values will be converted to f64 + #[cfg(any(feature = "luau", feature = "lua51", feature = "luajit"))] + { + assert_ne!(f.call::(MAX_SAFE_INTEGER + 2)?, MAX_SAFE_INTEGER + 2); + assert_ne!(f.call::(MIN_SAFE_INTEGER - 2)?, MIN_SAFE_INTEGER - 2); + assert_eq!(f.call::(i64::MAX)?, i64::MAX as f64); + } Ok(()) } @@ -558,9 +580,9 @@ fn test_num_conversion() -> Result<()> { assert_eq!(lua.load("1.0").eval::()?, 1); assert_eq!(lua.load("1.0").eval::()?, 1.0); - #[cfg(any(feature = "lua54", feature = "lua53"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] assert_eq!(lua.load("1.0").eval::()?, "1.0"); - #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] + #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit", feature = "luau"))] assert_eq!(lua.load("1.0").eval::()?, "1"); assert_eq!(lua.load("1.5").eval::()?, 1); @@ -580,6 +602,21 @@ fn test_num_conversion() -> Result<()> { assert_eq!(lua.unpack::(lua.pack(1i128 << 64)?)?, 1i128 << 64); + // Negative zero + let negative_zero = lua.load("-0.0").eval::()?; + assert_eq!(negative_zero, 0.0); + // LuaJIT treats -0.0 as a positive zero + #[cfg(not(feature = "luajit"))] + assert!(negative_zero.is_sign_negative()); + + // In Lua <5.3 all numbers are floats + #[cfg(not(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luajit")))] + { + let negative_zero = lua.load("-0").eval::()?; + assert_eq!(negative_zero, 0.0); + assert!(negative_zero.is_sign_negative()); + } + Ok(()) } @@ -633,25 +670,25 @@ fn test_pcall_xpcall() -> Result<()> { ) .exec()?; - assert_eq!(globals.get::<_, bool>("pcall_status")?, false); - assert_eq!(globals.get::<_, String>("pcall_error")?, "testerror"); + assert_eq!(globals.get::("pcall_status")?, false); + assert_eq!(globals.get::("pcall_error")?, "testerror"); - assert_eq!(globals.get::<_, bool>("xpcall_statusr")?, false); + assert_eq!(globals.get::("xpcall_statusr")?, false); #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "luajit" ))] - assert_eq!( - globals.get::<_, std::string::String>("xpcall_error")?, - "testerror" - ); + assert_eq!(globals.get::("xpcall_error")?, "testerror"); #[cfg(feature = "lua51")] - assert!(globals - .get::<_, String>("xpcall_error")? - .to_str()? - .ends_with(": testerror")); + assert!( + globals + .get::("xpcall_error")? + .to_str()? + .ends_with(": testerror") + ); // Make sure that weird xpcall error recursion at least doesn't cause unsafety or panics. lua.load( @@ -662,9 +699,7 @@ fn test_pcall_xpcall() -> Result<()> { "#, ) .exec()?; - let _ = globals - .get::<_, Function>("xpcall_recursion")? - .call::<_, ()>(()); + let _ = globals.get::("xpcall_recursion")?.call::<()>(()); Ok(()) } @@ -674,23 +709,22 @@ fn test_recursive_mut_callback_error() -> Result<()> { let lua = Lua::new(); let mut v = Some(Box::new(123)); - let f = lua.create_function_mut::<_, (), _>(move |lua, mutate: bool| { + let f = lua.create_function_mut(move |lua, mutate: bool| { if mutate { v = None; } else { // Produce a mutable reference let r = v.as_mut().unwrap(); // Whoops, this will recurse into the function and produce another mutable reference! - lua.globals().get::<_, Function>("f")?.call::<_, ()>(true)?; + lua.globals().get::("f")?.call::<()>(true)?; println!("Should not get here, mutable aliasing has occurred!"); - println!("value at {:p}", r as *mut _); - println!("value is {}", r); + println!("value at {:p} is {r}", r as *mut _); } Ok(()) })?; lua.globals().set("f", f)?; - match lua.globals().get::<_, Function>("f")?.call::<_, ()>(false) { + match lua.globals().get::("f")?.call::<()>(false) { Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { Error::CallbackError { ref cause, .. } => match *cause.as_ref() { Error::RecursiveMutCallback { .. } => {} @@ -721,13 +755,13 @@ fn test_set_metatable_nil() -> Result<()> { fn test_named_registry_value() -> Result<()> { let lua = Lua::new(); - lua.set_named_registry_value::<_, i32>("test", 42)?; + lua.set_named_registry_value("test", 42)?; let f = lua.create_function(move |lua, ()| { - assert_eq!(lua.named_registry_value::<_, i32>("test")?, 42); + assert_eq!(lua.named_registry_value::("test")?, 42); Ok(()) })?; - f.call::<_, ()>(())?; + f.call::<()>(())?; lua.unset_named_registry_value("test")?; match lua.named_registry_value("test")? { @@ -742,7 +776,7 @@ fn test_named_registry_value() -> Result<()> { fn test_registry_value() -> Result<()> { let lua = Lua::new(); - let mut r = Some(lua.create_registry_value::(42)?); + let mut r = Some(lua.create_registry_value(42)?); let f = lua.create_function_mut(move |lua, ()| { if let Some(r) = r.take() { assert_eq!(lua.registry_value::(&r)?, 42); @@ -753,14 +787,14 @@ fn test_registry_value() -> Result<()> { Ok(()) })?; - f.call::<_, ()>(())?; + f.call::<()>(())?; Ok(()) } #[test] fn test_drop_registry_value() -> Result<()> { - struct MyUserdata(Arc<()>); + struct MyUserdata(#[allow(unused)] Arc<()>); impl UserData for MyUserdata {} @@ -784,9 +818,19 @@ fn test_drop_registry_value() -> Result<()> { fn test_replace_registry_value() -> Result<()> { let lua = Lua::new(); - let key = lua.create_registry_value::(42)?; - lua.replace_registry_value(&key, "new value")?; + let mut key = lua.create_registry_value(42)?; + lua.replace_registry_value(&mut key, "new value")?; assert_eq!(lua.registry_value::(&key)?, "new value"); + lua.replace_registry_value(&mut key, Value::Nil)?; + assert_eq!(lua.registry_value::(&key)?, Value::Nil); + lua.replace_registry_value(&mut key, 123)?; + assert_eq!(lua.registry_value::(&key)?, 123); + + let mut key2 = lua.create_registry_value(Value::Nil)?; + lua.replace_registry_value(&mut key2, Value::Nil)?; + assert_eq!(lua.registry_value::(&key2)?, Value::Nil); + lua.replace_registry_value(&mut key2, "abc")?; + assert_eq!(lua.registry_value::(&key2)?, "abc"); Ok(()) } @@ -839,29 +883,79 @@ fn test_mismatched_registry_key() -> Result<()> { } #[test] +fn test_registry_value_reuse() -> Result<()> { + let lua = Lua::new(); + + let r1 = lua.create_registry_value("value1")?; + let r1_slot = format!("{r1:?}"); + drop(r1); + + // Previous slot must not be reused by nil value + let r2 = lua.create_registry_value(Value::Nil)?; + let r2_slot = format!("{r2:?}"); + assert_ne!(r1_slot, r2_slot); + drop(r2); + + // But should be reused by non-nil value + let r3 = lua.create_registry_value("value3")?; + let r3_slot = format!("{r3:?}"); + assert_eq!(r1_slot, r3_slot); + + Ok(()) +} + +#[test] +#[cfg(not(panic = "abort"))] fn test_application_data() -> Result<()> { let lua = Lua::new(); lua.set_app_data("test1"); lua.set_app_data(vec!["test2"]); + // Borrow &str immutably and Vec<&str> mutably + let s = lua.app_data_ref::<&str>().unwrap(); + let mut v = lua.app_data_mut::>().unwrap(); + v.push("test3"); + + // Insert of new data or removal should fail now + assert!(lua.try_set_app_data::(123).is_err()); + match catch_unwind(AssertUnwindSafe(|| lua.set_app_data::(123))) { + Ok(_) => panic!("expected panic"), + Err(_) => {} + } + match catch_unwind(AssertUnwindSafe(|| lua.remove_app_data::())) { + Ok(_) => panic!("expected panic"), + Err(_) => {} + } + + // Check display and debug impls + assert_eq!(format!("{s}"), "test1"); + assert_eq!(format!("{s:?}"), "\"test1\""); + + // Borrowing immutably and mutably of the same type is not allowed + assert!(lua.try_app_data_mut::<&str>().is_err()); + match catch_unwind(AssertUnwindSafe(|| lua.app_data_mut::<&str>().unwrap())) { + Ok(_) => panic!("expected panic"), + Err(_) => {} + } + assert!(lua.try_app_data_ref::>().is_err()); + drop((s, v)); + + // Test that application data is accessible from anywhere let f = lua.create_function(|lua, ()| { - { - let data1 = lua.app_data_ref::<&str>().unwrap(); - assert_eq!(*data1, "test1"); - } - let mut data2 = lua.app_data_mut::>().unwrap(); - assert_eq!(*data2, vec!["test2"]); - data2.push("test3"); + let mut data1 = lua.app_data_mut::<&str>().unwrap(); + assert_eq!(*data1, "test1"); + *data1 = "test4"; + + let data2 = lua.app_data_ref::>().unwrap(); + assert_eq!(*data2, vec!["test2", "test3"]); + Ok(()) })?; - f.call(())?; + f.call::<()>(())?; - assert_eq!(*lua.app_data_ref::<&str>().unwrap(), "test1"); - assert_eq!( - *lua.app_data_ref::>().unwrap(), - vec!["test2", "test3"] - ); + assert_eq!(*lua.app_data_ref::<&str>().unwrap(), "test4"); + assert_eq!(*lua.app_data_ref::>().unwrap(), vec!["test2", "test3"]); lua.remove_app_data::>(); assert!(matches!(lua.app_data_ref::>(), None)); @@ -870,95 +964,118 @@ fn test_application_data() -> Result<()> { } #[test] +fn test_rust_function() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + lua.load( + r#" + function lua_function() + return rust_function() + end + + -- Test to make sure chunk return is ignored + return 1 + "#, + ) + .exec()?; + + let lua_function = globals.get::("lua_function")?; + let rust_function = lua.create_function(|_, ()| Ok("hello"))?; + + globals.set("rust_function", rust_function)?; + assert_eq!(lua_function.call::(())?, "hello"); + + Ok(()) +} + +#[test] +fn test_c_function() -> Result<()> { + let lua = Lua::new(); + + extern "C-unwind" fn c_function(state: *mut mlua::lua_State) -> std::os::raw::c_int { + unsafe { + ffi::lua_pushboolean(state, 1); + ffi::lua_setglobal(state, b"c_function\0" as *const _ as *const _); + } + 0 + } + + let func = unsafe { lua.create_c_function(c_function)? }; + func.call::<()>(())?; + assert_eq!(lua.globals().get::("c_function")?, true); + + Ok(()) +} + +#[test] +#[cfg(not(target_arch = "wasm32"))] fn test_recursion() -> Result<()> { let lua = Lua::new(); let f = lua.create_function(move |lua, i: i32| { if i < 64 { - lua.globals() - .get::<_, Function>("f")? - .call::<_, ()>(i + 1)?; + lua.globals().get::("f")?.call::<()>(i + 1)?; } Ok(()) })?; - lua.globals().set("f", f.clone())?; - f.call::<_, ()>(1)?; + lua.globals().set("f", &f)?; + f.call::<()>(1)?; Ok(()) } #[test] +#[cfg(not(target_arch = "wasm32"))] fn test_too_many_returns() -> Result<()> { let lua = Lua::new(); let f = lua.create_function(|_, ()| Ok(Variadic::from_iter(1..1000000)))?; - assert!(f.call::<_, Vec>(()).is_err()); + assert!(f.call::>(()).is_err()); Ok(()) } #[test] +#[cfg(not(target_arch = "wasm32"))] fn test_too_many_arguments() -> Result<()> { let lua = Lua::new(); lua.load("function test(...) end").exec()?; let args = Variadic::from_iter(1..1000000); - assert!(lua - .globals() - .get::<_, Function>("test")? - .call::<_, ()>(args) - .is_err()); + assert!(lua.globals().get::("test")?.call::<()>(args).is_err()); Ok(()) } #[test] #[cfg(not(feature = "luajit"))] +#[cfg(not(target_arch = "wasm32"))] fn test_too_many_recursions() -> Result<()> { let lua = Lua::new(); - let f = lua - .create_function(move |lua, ()| lua.globals().get::<_, Function>("f")?.call::<_, ()>(()))?; + let f = lua.create_function(move |lua, ()| lua.globals().get::("f")?.call::<()>(()))?; - lua.globals().set("f", f.clone())?; - assert!(f.call::<_, ()>(()).is_err()); - - Ok(()) -} - -#[test] -fn test_too_many_binds() -> Result<()> { - let lua = Lua::new(); - let globals = lua.globals(); - lua.load( - r#" - function f(...) - end - "#, - ) - .exec()?; - - let concat = globals.get::<_, Function>("f")?; - assert!(concat.bind(Variadic::from_iter(1..1000000)).is_err()); - assert!(concat - .call::<_, ()>(Variadic::from_iter(1..1000000)) - .is_err()); + lua.globals().set("f", &f)?; + assert!(f.call::<()>(()).is_err()); Ok(()) } #[test] +#[cfg(not(target_arch = "wasm32"))] fn test_ref_stack_exhaustion() { match catch_unwind(AssertUnwindSafe(|| -> Result<()> { let lua = Lua::new(); let mut vals = Vec::new(); - for _ in 0..1000000 { + for _ in 0..10000000 { vals.push(lua.create_table()?); } Ok(()) })) { Ok(_) => panic!("no panic was detected"), - Err(p) => assert!(p - .downcast::() - .unwrap() - .starts_with("cannot create a Lua reference, out of auxiliary stack space")), + Err(p) => assert!( + p.downcast::() + .unwrap() + .starts_with("cannot create a Lua reference, out of auxiliary stack space") + ), } } @@ -989,10 +1106,7 @@ fn test_large_args() -> Result<()> { ) .eval()?; - assert_eq!( - f.call::<_, usize>((0..100).collect::>())?, - 4950 - ); + assert_eq!(f.call::((0..100).collect::>())?, 4950); Ok(()) } @@ -1008,7 +1122,7 @@ fn test_large_args_ref() -> Result<()> { Ok(()) })?; - f.call::<_, ()>((0..100).map(|i| i.to_string()).collect::>())?; + f.call::<()>((0..100).map(|i| i.to_string()).collect::>())?; Ok(()) } @@ -1030,7 +1144,7 @@ fn test_chunk_env() -> Result<()> { test_var = 1 "#, ) - .set_environment(env1.clone())? + .set_environment(env1.clone()) .exec()?; lua.load( @@ -1039,18 +1153,11 @@ fn test_chunk_env() -> Result<()> { test_var = 2 "#, ) - .set_environment(env2.clone())? + .set_environment(env2.clone()) .exec()?; - assert_eq!( - lua.load("test_var").set_environment(env1)?.eval::()?, - 1 - ); - - assert_eq!( - lua.load("test_var").set_environment(env2)?.eval::()?, - 2 - ); + assert_eq!(lua.load("test_var").set_environment(env1).eval::()?, 1); + assert_eq!(lua.load("test_var").set_environment(env2).eval::()?, 2); Ok(()) } @@ -1069,19 +1176,20 @@ fn test_context_thread() -> Result<()> { .into_function()?; #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "luajit52" ))] - f.call::<_, ()>(lua.current_thread())?; + f.call::<()>(lua.current_thread())?; #[cfg(any( feature = "lua51", all(feature = "luajit", not(feature = "luajit52")), feature = "luau" ))] - f.call::<_, ()>(Nil)?; + f.call::<()>(Nil)?; Ok(()) } @@ -1102,7 +1210,7 @@ fn test_context_thread_51() -> Result<()> { .eval()?, )?; - thread.resume::<_, ()>(thread.clone())?; + thread.resume::<()>(thread.clone())?; Ok(()) } @@ -1112,44 +1220,99 @@ fn test_context_thread_51() -> Result<()> { fn test_jit_version() -> Result<()> { let lua = Lua::new(); let jit: Table = lua.globals().get("jit")?; - assert!(jit - .get::<_, String>("version")? - .to_str()? - .contains("LuaJIT")); + assert!( + jit.get::("version")? + .to_str()? + .contains("LuaJIT") + ); + Ok(()) +} + +#[test] +fn test_register_module() -> Result<()> { + let lua = Lua::new(); + + let t = lua.create_table()?; + t.set("name", "my_module")?; + lua.register_module("@my_module", &t)?; + + lua.load( + r#" + local my_module = require("@my_module") + assert(my_module.name == "my_module") + "#, + ) + .exec()?; + + lua.unload_module("@my_module")?; + lua.load( + r#" + local ok, err = pcall(function() return require("@my_module") end) + assert(not ok) + "#, + ) + .exec()?; + + #[cfg(feature = "luau")] + { + // Luau registered modules must have '@' prefix + let res = lua.register_module("my_module", 123); + assert!(res.is_err()); + assert_eq!( + res.unwrap_err().to_string(), + "runtime error: module name must begin with '@'" + ); + + // Luau registered modules (aliases) are case-insensitive + let res = lua.register_module("@My_Module", &t); + assert!(res.is_ok()); + lua.load( + r#" + local my_module = require("@MY_MODule") + assert(my_module.name == "my_module") + "#, + ) + .exec()?; + } + Ok(()) } #[test] -fn test_load_from_function() -> Result<()> { +#[cfg(not(feature = "luau"))] +fn test_preload_module() -> Result<()> { let lua = Lua::new(); - let i = Arc::new(AtomicU32::new(0)); - let i2 = i.clone(); - let func = lua.create_function(move |lua, modname: String| { - i2.fetch_add(1, Ordering::Relaxed); + let loader = lua.create_function(move |lua, modname: String| { let t = lua.create_table()?; - t.set("__name", modname)?; + t.set("name", modname)?; Ok(t) })?; - let t: Table = lua.load_from_function("my_module", func.clone())?; - assert_eq!(t.get::<_, String>("__name")?, "my_module"); - assert_eq!(i.load(Ordering::Relaxed), 1); - - let _: Value = lua.load_from_function("my_module", func.clone())?; - assert_eq!(i.load(Ordering::Relaxed), 1); - - let func_nil = lua.create_function(move |_, _: String| Ok(Value::Nil))?; - let v: Value = lua.load_from_function("my_module2", func_nil)?; - assert_eq!(v, Value::Boolean(true)); + lua.preload_module("@my_module", loader.clone())?; + lua.load( + r#" + -- `my_module` is global for purposes of next test + my_module = require("@my_module") + assert(my_module.name == "@my_module") + local my_module2 = require("@my_module") + assert(my_module == my_module2) + "#, + ) + .exec() + .unwrap(); // Test unloading and loading again - lua.unload("my_module")?; - let _: Value = lua.load_from_function("my_module", func)?; - assert_eq!(i.load(Ordering::Relaxed), 2); - - // Unloading nonexistent module must not fail - lua.unload("my_module2")?; + lua.unload_module("@my_module")?; + lua.load( + r#" + local my_module3 = require("@my_module") + -- `my_module` is not equal to `my_module3` because it was reloaded + assert(my_module ~= my_module3) + "#, + ) + .exec() + .unwrap(); Ok(()) } @@ -1159,14 +1322,18 @@ fn test_inspect_stack() -> Result<()> { let lua = Lua::new(); // Not inside any function - assert!(lua.inspect_stack(0).is_none()); - - let logline = lua.create_function(|lua, msg: StdString| { - let debug = lua.inspect_stack(1).unwrap(); // caller - let source = debug.source().short_src.map(core::str::from_utf8); - let source = source.transpose().unwrap().unwrap_or("?"); - let line = debug.curr_line(); - Ok(format!("{}:{} {}", source, line, msg)) + assert!(lua.inspect_stack(0, |_| ()).is_none()); + + let logline = lua.create_function(|lua, msg: String| { + let r = lua + .inspect_stack(1, |debug| { + let source = debug.source().short_src; + let source = source.as_deref().unwrap_or("?"); + let line = debug.current_line().unwrap(); + format!("{}:{} {}", source, line, msg) + }) + .unwrap(); + Ok(r) })?; lua.globals().set("logline", logline)?; @@ -1185,7 +1352,136 @@ fn test_inspect_stack() -> Result<()> { assert(logline("world") == '[string "chunk"]:12 world') "#, ) - .set_name("chunk")? + .set_name("chunk") + .exec()?; + + let stack_info = lua.create_function(|lua, ()| { + let stack_info = lua.inspect_stack(1, |debug| debug.stack()).unwrap(); + Ok(format!("{stack_info:?}")) + })?; + lua.globals().set("stack_info", stack_info)?; + + #[cfg(any( + feature = "lua55", + feature = "lua54", + feature = "lua53", + feature = "lua52", + feature = "luau" + ))] + lua.load( + r#" + local stack_info = stack_info + local function baz(a, b, c, ...) + return stack_info() + end + assert(baz() == 'DebugStack { num_upvalues: 1, num_params: 3, is_vararg: true }') + "#, + ) + .exec()?; + + // LuaJIT does not pass this test for some reason + #[cfg(feature = "lua51")] + lua.load( + r#" + local stack_info = stack_info + local function baz(a, b, c, ...) + return stack_info() + end + assert(baz() == 'DebugStack { num_upvalues: 1 }') + "#, + ) + .exec()?; + + // Test retrieving currently running function + let running_function = + lua.create_function(|lua, ()| Ok(lua.inspect_stack(1, |debug| debug.function())))?; + lua.globals().set("running_function", running_function)?; + lua.load( + r#" + local function baz() + return running_function() + end + if jit == nil then + assert(baz() == baz) + else + -- luajit inline the "baz" function and returns the chunk itself + assert(baz() == running_function()) + end + "#, + ) + .exec()?; + + Ok(()) +} + +#[test] +fn test_traceback() -> Result<()> { + let lua = Lua::new(); + + // Test traceback at level 0 (not inside any function) + let traceback = lua.traceback(None, 0)?.to_string_lossy(); + assert!(traceback.contains("stack traceback:")); + + // Test traceback with a message prefix + let traceback = lua.traceback(Some("error occurred"), 0)?.to_string_lossy(); + assert!(traceback.starts_with("error occurred")); + assert!(traceback.contains("stack traceback:")); + + // Test traceback inside a function + let get_traceback = lua + .create_function(|lua, (msg, level): (Option, usize)| lua.traceback(msg.as_deref(), level))?; + lua.globals().set("get_traceback", get_traceback)?; + + lua.load( + r#" + local function foo() + -- Level 1 is inside foo (the caller) + local traceback = get_traceback(nil, 1) + return traceback + end + local function bar() + local result = foo() + return result + end + local function baz() + local result = bar() + return result + end + + local traceback = baz() + assert(traceback:match("in %a+ 'foo'")) + assert(traceback:match("in %a+ 'bar'")) + assert(traceback:match("in %a+ 'baz'")) + "#, + ) + .exec()?; + + // Test traceback at different levels + lua.load( + r#" + local function foo() + local tb0 = get_traceback(nil, 0) + local tb1 = get_traceback(nil, 1) + local tb2 = get_traceback(nil, 2) + return tb0, tb1, tb2 + end + local function bar() + local tb0, tb1, tb2 = foo() + return tb0, tb1, tb2 + end + + local tb0, tb1, tb2 = bar() + + assert(tb0:match("in %a+ 'get_traceback'")) + assert(tb0:match("in %a+ 'foo'")) + + assert(not tb1:match("in %a+ 'get_traceback'")) + assert(tb1:match("in %a+ 'foo'")) + + assert(not tb2:match("in %a+ 'foo'")) + assert(tb1:match("in %a+ 'bar'")) + "#, + ) .exec()?; Ok(()) @@ -1197,7 +1493,7 @@ fn test_multi_states() -> Result<()> { let f = lua.create_function(|_, g: Option| { if let Some(g) = g { - g.call(())?; + g.call::<()>(())?; } Ok(()) })?; @@ -1210,54 +1506,61 @@ fn test_multi_states() -> Result<()> { } #[test] -#[cfg(feature = "lua54")] +#[cfg(any(feature = "lua55", feature = "lua54"))] fn test_warnings() -> Result<()> { let lua = Lua::new(); - lua.set_app_data::>(Vec::new()); + lua.set_app_data::>(Vec::new()); - lua.set_warning_function(|lua, msg, tocont| { - let msg = msg.to_string_lossy().to_string(); - lua.app_data_mut::>() + lua.set_warning_function(|lua, msg, incomplete| { + lua.app_data_mut::>() .unwrap() - .push((msg, tocont)); + .push((msg.to_string(), incomplete)); Ok(()) }); - lua.warning("native warning ...", true)?; - lua.warning("finish", false)?; + lua.warning("native warning ...", true); + lua.warning("finish", false); + lua.warning("\0", false); lua.load(r#"warn("lua warning", "continue")"#).exec()?; lua.remove_warning_function(); - lua.warning("one more warning", false)?; + lua.warning("one more warning", false); - let messages = lua.app_data_ref::>().unwrap(); + let messages = lua.app_data_ref::>().unwrap(); assert_eq!( *messages, vec![ ("native warning ...".to_string(), true), ("finish".to_string(), false), + ("".to_string(), false), ("lua warning".to_string(), true), ("continue".to_string(), false), ] ); // Trigger error inside warning - lua.set_warning_function(|_, _, _| Err(Error::RuntimeError("warning error".to_string()))); + lua.set_warning_function(|_, _, _| Err(Error::runtime("warning error"))); assert!(matches!( lua.load(r#"warn("test")"#).exec(), - Err(Error::CallbackError { cause, .. }) - if matches!(*cause, Error::RuntimeError(ref err) if err == "warning error") + Err(Error::RuntimeError(ref err)) if err == "warning error" )); + // Recursive warning + lua.set_warning_function(|lua, _, _| { + lua.warning("inner", false); + Ok(()) + }); + lua.warning("hello", false); + Ok(()) } #[test] #[cfg(feature = "luajit")] -#[should_panic] -fn test_luajit_cdata() { +fn test_luajit_cdata() -> Result<()> { let lua = unsafe { Lua::unsafe_new() }; - let _v: Result = lua + + let cdata = lua .load( r#" local ffi = require("ffi") @@ -1270,5 +1573,108 @@ fn test_luajit_cdata() { return ptr "#, ) - .eval(); + .eval::()?; + assert_eq!(cdata.type_name(), "other"); + assert!(cdata.to_string()?.starts_with("cdata:")); + + Ok(()) +} + +#[test] +#[cfg(feature = "send")] +#[cfg(not(target_arch = "wasm32"))] +fn test_multi_thread() -> Result<()> { + let lua = Lua::new(); + + lua.globals().set("i", 0)?; + let func = lua.load("i = i + 1").into_function()?; + + std::thread::scope(|s| { + s.spawn(|| { + for _ in 0..5 { + func.call::<()>(()).unwrap(); + } + }); + s.spawn(|| { + for _ in 0..5 { + func.call::<()>(()).unwrap(); + } + }); + }); + + assert_eq!(lua.globals().get::("i")?, 10); + + Ok(()) +} + +#[test] +fn test_exec_raw() -> Result<()> { + let lua = Lua::new(); + + let sum = lua.create_function(|_, args: Variadic| { + let mut sum = 0; + for i in args { + sum += i; + } + Ok(sum) + })?; + lua.globals().set("sum", sum)?; + + let n: i32 = unsafe { + lua.exec_raw((), |state| { + ffi::lua_getglobal(state, b"sum\0".as_ptr() as _); + ffi::lua_pushinteger(state, 1); + ffi::lua_pushinteger(state, 7); + ffi::lua_call(state, 2, 1); + }) + }?; + assert_eq!(n, 8); + + // Test error handling + let res: Result<()> = unsafe { + lua.exec_raw("test error", |state| { + ffi::lua_error(state); + }) + }; + assert!(matches!(res, Err(Error::RuntimeError(err)) if err.contains("test error"))); + + Ok(()) +} + +#[test] +fn test_gc_drop_ref_thread() -> Result<()> { + let lua = Lua::new(); + + let t = lua.create_table()?; + lua.create_function(move |_, ()| { + _ = &t; + Ok(()) + })?; + + for _ in 0..10000 { + // GC will run eventually to collect the function and the table above + lua.create_table()?; + } + + Ok(()) +} + +#[cfg(not(feature = "luau"))] +#[test] +fn test_get_or_init_from_ptr() -> Result<()> { + // This would not work with Luau, the state must be init by mlua internally + let state = unsafe { ffi::luaL_newstate() }; + + let mut lua = unsafe { Lua::get_or_init_from_ptr(state) }; + lua.globals().set("hello", "world678")?; + + // The same Lua instance must be returned + lua = unsafe { Lua::get_or_init_from_ptr(state) }; + assert_eq!(lua.globals().get::("hello")?, "world678"); + + unsafe { ffi::lua_close(state) }; + + // Lua must not be accessed after closing + + Ok(()) } diff --git a/tests/thread.rs b/tests/thread.rs index b664053f..98b861f8 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -1,6 +1,6 @@ use std::panic::catch_unwind; -use mlua::{Error, Function, Lua, Result, Thread, ThreadStatus}; +use mlua::{Error, Function, IntoLua, Lua, Result, Thread, Value}; #[test] fn test_thread() -> Result<()> { @@ -21,17 +21,17 @@ fn test_thread() -> Result<()> { .eval()?, )?; - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(0)?, 0); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(1)?, 1); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(2)?, 3); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(3)?, 6); - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(4)?, 10); - assert_eq!(thread.status(), ThreadStatus::Unresumable); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(0)?, 0); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(1)?, 1); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(2)?, 3); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(3)?, 6); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(4)?, 10); + assert!(thread.is_finished()); let accumulate = lua.create_thread( lua.load( @@ -47,12 +47,12 @@ fn test_thread() -> Result<()> { )?; for i in 0..4 { - accumulate.resume::<_, ()>(i)?; + accumulate.resume::<()>(i)?; } - assert_eq!(accumulate.resume::<_, i64>(4)?, 10); - assert_eq!(accumulate.status(), ThreadStatus::Resumable); - assert!(accumulate.resume::<_, ()>("error").is_err()); - assert_eq!(accumulate.status(), ThreadStatus::Error); + assert_eq!(accumulate.resume::(4)?, 10); + assert!(accumulate.is_resumable()); + assert!(accumulate.resume::<()>("error").is_err()); + assert!(accumulate.is_error()); let thread = lua .load( @@ -65,8 +65,8 @@ fn test_thread() -> Result<()> { "#, ) .eval::()?; - assert_eq!(thread.status(), ThreadStatus::Resumable); - assert_eq!(thread.resume::<_, i64>(())?, 42); + assert!(thread.is_resumable()); + assert_eq!(thread.resume::(())?, 42); let thread: Thread = lua .load( @@ -81,45 +81,54 @@ fn test_thread() -> Result<()> { ) .eval()?; - assert_eq!(thread.resume::<_, u32>(42)?, 123); - assert_eq!(thread.resume::<_, u32>(43)?, 987); + assert_eq!(thread.resume::(42)?, 123); + assert_eq!(thread.resume::(43)?, 987); - match thread.resume::<_, u32>(()) { - Err(Error::CoroutineInactive) => {} + match thread.resume::(()) { + Err(Error::CoroutineUnresumable) => {} Err(_) => panic!("resuming dead coroutine error is not CoroutineInactive kind"), _ => panic!("resuming dead coroutine did not return error"), } + // Already running thread must be unresumable + let thread = lua.create_thread(lua.create_function(|lua, ()| { + assert!(lua.current_thread().is_running()); + let result = lua.current_thread().resume::<()>(()); + assert!( + matches!(result, Err(Error::CoroutineUnresumable)), + "unexpected result: {result:?}", + ); + Ok(()) + })?)?; + let result = thread.resume::<()>(()); + assert!(result.is_ok(), "unexpected result: {result:?}"); + Ok(()) } #[test] -#[cfg(any( - feature = "lua54", - all(feature = "luajit", feature = "vendored"), - feature = "luau", -))] fn test_thread_reset() -> Result<()> { use mlua::{AnyUserData, UserData}; use std::sync::Arc; let lua = Lua::new(); - struct MyUserData(Arc<()>); + struct MyUserData(#[allow(unused)] Arc<()>); impl UserData for MyUserData {} let arc = Arc::new(()); let func: Function = lua.load(r#"function(ud) coroutine.yield(ud) end"#).eval()?; - let thread = lua.create_thread(func.clone())?; + let thread = lua.create_thread(lua.load("return 0").into_function()?)?; // Dummy function first + assert!(thread.reset(func.clone()).is_ok()); for _ in 0..2 { - assert_eq!(thread.status(), ThreadStatus::Resumable); - let _ = thread.resume::<_, AnyUserData>(MyUserData(arc.clone()))?; - assert_eq!(thread.status(), ThreadStatus::Resumable); + assert!(thread.is_resumable()); + let _ = thread.resume::(MyUserData(arc.clone()))?; + assert!(thread.is_resumable()); assert_eq!(Arc::strong_count(&arc), 2); - thread.resume::<_, ()>(())?; - assert_eq!(thread.status(), ThreadStatus::Unresumable); + thread.resume::<()>(())?; + assert!(thread.is_finished()); thread.reset(func.clone())?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&arc), 1); @@ -128,28 +137,39 @@ fn test_thread_reset() -> Result<()> { // Check for errors let func: Function = lua.load(r#"function(ud) error("test error") end"#).eval()?; let thread = lua.create_thread(func.clone())?; - let _ = thread.resume::<_, AnyUserData>(MyUserData(arc.clone())); - assert_eq!(thread.status(), ThreadStatus::Error); + let _ = thread.resume::(MyUserData(arc.clone())); + assert!(thread.is_error()); assert_eq!(Arc::strong_count(&arc), 2); - #[cfg(feature = "lua54")] + #[cfg(any(feature = "lua55", feature = "lua54"))] { assert!(thread.reset(func.clone()).is_err()); // Reset behavior has changed in Lua v5.4.4 // It's became possible to force reset thread by popping error object - assert!(matches!( - thread.status(), - ThreadStatus::Unresumable | ThreadStatus::Error - )); - // Would pass in 5.4.4 - // assert!(thread.reset(func.clone()).is_ok()); - // assert_eq!(thread.status(), ThreadStatus::Resumable); + assert!(thread.is_finished()); + assert!(thread.reset(func.clone()).is_ok()); + assert!(thread.is_resumable()); } - #[cfg(any(feature = "lua54", feature = "luau"))] + #[cfg(any(feature = "lua55", feature = "lua54", feature = "luau"))] { assert!(thread.reset(func.clone()).is_ok()); - assert_eq!(thread.status(), ThreadStatus::Resumable); + assert!(thread.is_resumable()); } + // Try reset running thread + let thread = lua.create_thread(lua.create_function(|lua, ()| { + let this = lua.current_thread(); + this.reset(lua.create_function(|_, ()| Ok(()))?)?; + Ok(()) + })?)?; + let result = thread.resume::<()>(()); + assert!( + matches!(result, Err(Error::CallbackError{ ref cause, ..}) + if matches!(cause.as_ref(), Error::RuntimeError(err) + if err == "cannot reset a running thread") + ), + "unexpected result: {result:?}", + ); + Ok(()) } @@ -161,6 +181,7 @@ fn test_coroutine_from_closure() -> Result<()> { lua.globals().set("main", thrd_main)?; #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -173,12 +194,13 @@ fn test_coroutine_from_closure() -> Result<()> { .load("coroutine.create(function(...) return main(unpack(arg)) end)") .eval()?; - thrd.resume::<_, ()>(())?; + thrd.resume::<()>(())?; Ok(()) } #[test] +#[cfg(not(panic = "abort"))] fn test_coroutine_panic() { match catch_unwind(|| -> Result<()> { // check that coroutines propagate panics correctly @@ -186,7 +208,7 @@ fn test_coroutine_panic() { let thrd_main = lua.create_function(|_, ()| -> Result<()> { panic!("test_panic"); })?; - lua.globals().set("main", thrd_main.clone())?; + lua.globals().set("main", &thrd_main)?; let thrd: Thread = lua.create_thread(thrd_main)?; thrd.resume(()) }) { @@ -194,3 +216,62 @@ fn test_coroutine_panic() { Err(p) => assert!(*p.downcast::<&str>().unwrap() == "test_panic"), } } + +#[test] +fn test_thread_pointer() -> Result<()> { + let lua = Lua::new(); + + let func = lua.load("return 123").into_function()?; + let thread = lua.create_thread(func.clone())?; + + assert_eq!(thread.to_pointer(), thread.clone().to_pointer()); + assert_ne!(thread.to_pointer(), lua.current_thread().to_pointer()); + + Ok(()) +} + +#[test] +#[cfg(feature = "luau")] +fn test_thread_resume_error() -> Result<()> { + let lua = Lua::new(); + + let thread = lua + .load( + r#" + coroutine.create(function() + local ok, err = pcall(coroutine.yield, 123) + assert(not ok, "yield should fail") + assert(err == "myerror", "unexpected error: " .. tostring(err)) + return "success" + end) + "#, + ) + .eval::()?; + + assert_eq!(thread.resume::(())?, 123); + let status = thread.resume_error::("myerror").unwrap(); + assert_eq!(status, "success"); + + Ok(()) +} + +#[test] +fn test_thread_resume_bad_arg() -> Result<()> { + let lua = Lua::new(); + + struct BadArg; + + impl IntoLua for BadArg { + fn into_lua(self, _lua: &Lua) -> Result { + Err(Error::runtime("bad arg")) + } + } + + let f = lua.create_thread(lua.create_function(|_, ()| Ok("okay"))?)?; + let res = f.resume::<()>((123, BadArg)); + assert!(matches!(res, Err(Error::RuntimeError(msg)) if msg == "bad arg")); + let res = f.resume::(()).unwrap(); + assert_eq!(res, "okay"); + + Ok(()) +} diff --git a/tests/types.rs b/tests/types.rs index 72f9484f..6475acd6 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -1,6 +1,6 @@ use std::os::raw::c_void; -use mlua::{Function, LightUserData, Lua, Result}; +use mlua::{Error, Function, LightUserData, Lua, LuaString, Number, Result, Thread}; #[test] fn test_lightuserdata() -> Result<()> { @@ -17,10 +17,132 @@ fn test_lightuserdata() -> Result<()> { .exec()?; let res = globals - .get::<_, Function>("id")? - .call::<_, LightUserData>(LightUserData(42 as *mut c_void))?; + .get::("id")? + .call::(LightUserData(42 as *mut c_void))?; assert_eq!(res, LightUserData(42 as *mut c_void)); Ok(()) } + +#[test] +fn test_boolean_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set("__add", Function::wrap(|a, b| Ok::<_, mlua::Error>(a || b)))?; + assert_eq!(lua.type_metatable::(), None); + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::().unwrap(), mt); + + lua.load(r#"assert(true + true == true)"#).exec().unwrap(); + lua.load(r#"assert(true + false == true)"#).exec().unwrap(); + lua.load(r#"assert(false + true == true)"#).exec().unwrap(); + lua.load(r#"assert(false + false == false)"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_lightuserdata_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__add", + Function::wrap(|a: LightUserData, b: LightUserData| { + Ok::<_, Error>(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void)) + }), + )?; + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::().unwrap(), mt); + + let res = lua + .load( + r#" + local a, b = ... + return a + b + "#, + ) + .call::(( + LightUserData(42 as *mut c_void), + LightUserData(100 as *mut c_void), + )) + .unwrap(); + assert_eq!(res, LightUserData(142 as *mut c_void)); + + Ok(()) +} + +#[test] +fn test_number_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__call", + Function::wrap(|n1: f64, n2: f64| Ok::<_, Error>(n1 * n2)), + )?; + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::().unwrap(), mt); + + lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap(); + lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_string_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__add", + Function::wrap(|a: String, b: String| Ok::<_, Error>(format!("{a}{b}"))), + )?; + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::().unwrap(), mt); + + lua.load(r#"assert(("foo" + "bar") == "foobar")"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_function_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__index", + Function::wrap(|_: Function, key: String| Ok::<_, Error>(format!("function.{key}"))), + )?; + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::(), Some(mt)); + + lua.load(r#"assert((function() end).foo == "function.foo")"#) + .exec() + .unwrap(); + + Ok(()) +} + +#[test] +fn test_thread_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__index", + Function::wrap(|_: Thread, key: String| Ok::<_, Error>(format!("thread.{key}"))), + )?; + lua.set_type_metatable::(Some(mt.clone())); + assert_eq!(lua.type_metatable::(), Some(mt)); + + lua.load(r#"assert((coroutine.create(function() end)).foo == "thread.foo")"#) + .exec() + .unwrap(); + + Ok(()) +} diff --git a/tests/userdata.rs b/tests/userdata.rs index ce9d9d0e..c5233557 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -1,23 +1,17 @@ +use std::any::TypeId; +use std::collections::HashMap; use std::sync::Arc; -#[cfg(not(feature = "parking_lot"))] -use std::sync::{Mutex, RwLock}; -#[cfg(feature = "parking_lot")] -use parking_lot::{Mutex, RwLock}; - -#[cfg(not(feature = "send"))] -use std::{cell::RefCell, rc::Rc}; - -#[cfg(feature = "lua54")] +#[cfg(any(feature = "lua55", feature = "lua54"))] use std::sync::atomic::{AtomicI64, Ordering}; use mlua::{ - AnyUserData, Error, ExternalError, Function, Lua, MetaMethod, Nil, Result, String, UserData, - UserDataFields, UserDataMethods, Value, + AnyUserData, Error, ExternalError, Function, Lua, LuaString, MetaMethod, Nil, ObjectLike, Result, + UserData, UserDataFields, UserDataMethods, UserDataOwned, UserDataRef, UserDataRegistry, Value, Variadic, }; #[test] -fn test_user_data() -> Result<()> { +fn test_userdata() -> Result<()> { struct UserData1(i64); struct UserData2(Box); @@ -29,9 +23,11 @@ fn test_user_data() -> Result<()> { let userdata2 = lua.create_userdata(UserData2(Box::new(2)))?; assert!(userdata1.is::()); + assert!(userdata1.type_id() == Some(TypeId::of::())); assert!(!userdata1.is::()); assert!(userdata2.is::()); assert!(!userdata2.is::()); + assert!(userdata2.type_id() == Some(TypeId::of::())); assert_eq!(userdata1.borrow::()?.0, 1); assert_eq!(*userdata2.borrow::()?.0, 2); @@ -41,11 +37,11 @@ fn test_user_data() -> Result<()> { #[test] fn test_methods() -> Result<()> { - #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "serde", derive(serde::Serialize))] struct MyUserData(i64); impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("get_value", |_, data, ()| Ok(data.0)); methods.add_method_mut("set_value", |_, data, args| { data.0 = args; @@ -56,7 +52,7 @@ fn test_methods() -> Result<()> { fn check_methods(lua: &Lua, userdata: AnyUserData) -> Result<()> { let globals = lua.globals(); - globals.set("userdata", userdata.clone())?; + globals.set("userdata", &userdata)?; lua.load( r#" function get_it() @@ -69,13 +65,13 @@ fn test_methods() -> Result<()> { "#, ) .exec()?; - let get = globals.get::<_, Function>("get_it")?; - let set = globals.get::<_, Function>("set_it")?; - assert_eq!(get.call::<_, i64>(())?, 42); + let get = globals.get::("get_it")?; + let set = globals.get::("set_it")?; + assert_eq!(get.call::(())?, 42); userdata.borrow_mut::()?.0 = 64; - assert_eq!(get.call::<_, i64>(())?, 64); - set.call::<_, ()>(100)?; - assert_eq!(get.call::<_, i64>(())?, 100); + assert_eq!(get.call::(())?, 64); + set.call::<()>(100)?; + assert_eq!(get.call::(())?, 100); Ok(()) } @@ -84,39 +80,65 @@ fn test_methods() -> Result<()> { check_methods(&lua, lua.create_userdata(MyUserData(42))?)?; // Additionally check serializable userdata - #[cfg(feature = "serialize")] + #[cfg(feature = "serde")] check_methods(&lua, lua.create_ser_userdata(MyUserData(42))?)?; Ok(()) } +#[test] +fn test_method_variadic() -> Result<()> { + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { + methods.add_method("get", |_, data, ()| Ok(data.0)); + methods.add_method_mut("add", |_, data, vals: Variadic| { + data.0 += vals.into_iter().sum::(); + Ok(()) + }); + } + } + + let lua = Lua::new(); + let globals = lua.globals(); + globals.set("userdata", MyUserData(0))?; + lua.load("userdata:add(1, 5, -10)").exec()?; + let ud: UserDataRef = globals.get("userdata")?; + assert_eq!(ud.0, -4); + + Ok(()) +} + #[test] fn test_metamethods() -> Result<()> { #[derive(Copy, Clone)] struct MyUserData(i64); impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("get", |_, data, ()| Ok(data.0)); methods.add_meta_function( MetaMethod::Add, - |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 + rhs.0)), + |_, (lhs, rhs): (UserDataRef, UserDataRef)| Ok(MyUserData(lhs.0 + rhs.0)), ); methods.add_meta_function( MetaMethod::Sub, - |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)), + |_, (lhs, rhs): (UserDataRef, UserDataRef)| Ok(MyUserData(lhs.0 - rhs.0)), ); - methods.add_meta_function(MetaMethod::Eq, |_, (lhs, rhs): (MyUserData, MyUserData)| { - Ok(lhs.0 == rhs.0) - }); - methods.add_meta_method(MetaMethod::Index, |_, data, index: String| { + methods.add_meta_function( + MetaMethod::Eq, + |_, (lhs, rhs): (UserDataRef, UserDataRef)| Ok(lhs.0 == rhs.0), + ); + methods.add_meta_method(MetaMethod::Index, |_, data, index: LuaString| { if index.to_str()? == "inner" { Ok(data.0) } else { - Err("no such custom index".to_lua_err()) + Err("no such custom index".into_lua_err()) } }); #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -124,7 +146,7 @@ fn test_metamethods() -> Result<()> { ))] methods.add_meta_method(MetaMethod::Pairs, |lua, data, ()| { use std::iter::FromIterator; - let stateless_iter = lua.create_function(|_, (data, i): (MyUserData, i64)| { + let stateless_iter = lua.create_function(|_, (data, i): (UserDataRef, i64)| { let i = i + 1; if i <= data.0 { return Ok(mlua::Variadic::from_iter(vec![i, i])); @@ -142,11 +164,14 @@ fn test_metamethods() -> Result<()> { globals.set("userdata2", MyUserData(3))?; globals.set("userdata3", MyUserData(3))?; assert_eq!( - lua.load("userdata1 + userdata2").eval::()?.0, + lua.load("userdata1 + userdata2") + .eval::>()? + .0, 10 ); #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", @@ -166,43 +191,49 @@ fn test_metamethods() -> Result<()> { ) .eval::()?; - assert_eq!(lua.load("userdata1 - userdata2").eval::()?.0, 4); + assert_eq!( + lua.load("userdata1 - userdata2") + .eval::>()? + .0, + 4 + ); assert_eq!(lua.load("userdata1:get()").eval::()?, 7); assert_eq!(lua.load("userdata2.inner").eval::()?, 3); assert!(lua.load("userdata2.nonexist_field").eval::<()>().is_err()); #[cfg(any( + feature = "lua55", feature = "lua54", feature = "lua53", feature = "lua52", feature = "luajit52" ))] - assert_eq!(pairs_it.call::<_, i64>(())?, 28); + assert_eq!(pairs_it.call::(())?, 28); let userdata2: Value = globals.get("userdata2")?; let userdata3: Value = globals.get("userdata3")?; assert!(lua.load("userdata2 == userdata3").eval::()?); assert!(userdata2 != userdata3); // because references are differ - assert!(userdata2.equals(userdata3)?); + assert!(userdata2.equals(&userdata3)?); let userdata1: AnyUserData = globals.get("userdata1")?; - assert!(userdata1.get_metatable()?.contains(MetaMethod::Add)?); - assert!(userdata1.get_metatable()?.contains(MetaMethod::Sub)?); - assert!(userdata1.get_metatable()?.contains(MetaMethod::Index)?); - assert!(!userdata1.get_metatable()?.contains(MetaMethod::Pow)?); + assert!(userdata1.metatable()?.contains(MetaMethod::Add)?); + assert!(userdata1.metatable()?.contains(MetaMethod::Sub)?); + assert!(userdata1.metatable()?.contains(MetaMethod::Index)?); + assert!(!userdata1.metatable()?.contains(MetaMethod::Pow)?); Ok(()) } +#[cfg(any(feature = "lua55", feature = "lua54"))] #[test] -#[cfg(feature = "lua54")] fn test_metamethod_close() -> Result<()> { #[derive(Clone)] struct MyUserData(Arc); impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("get", |_, data, ()| Ok(data.0.load(Ordering::Relaxed))); methods.add_meta_method(MetaMethod::Close, |_, data, _err: Value| { data.0.store(0, Ordering::Relaxed); @@ -248,9 +279,9 @@ fn test_gc_userdata() -> Result<()> { } impl UserData for MyUserdata { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("access", |_, this, ()| { - assert!(this.id == 123); + assert_eq!(this.id, 123); Ok(()) }); } @@ -259,8 +290,8 @@ fn test_gc_userdata() -> Result<()> { let lua = Lua::new(); lua.globals().set("userdata", MyUserdata { id: 123 })?; - assert!(lua - .load( + assert!( + lua.load( r#" local tbl = setmetatable({ userdata = userdata @@ -276,7 +307,8 @@ fn test_gc_userdata() -> Result<()> { "# ) .exec() - .is_err()); + .is_err() + ); Ok(()) } @@ -287,12 +319,12 @@ fn test_userdata_take() -> Result<()> { struct MyUserdata(Arc); impl UserData for MyUserdata { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("num", |_, this, ()| Ok(*this.0)) } } - #[cfg(feature = "serialize")] + #[cfg(feature = "serde")] impl serde::Serialize for MyUserdata { fn serialize(&self, serializer: S) -> std::result::Result where @@ -303,23 +335,21 @@ fn test_userdata_take() -> Result<()> { } fn check_userdata_take(lua: &Lua, userdata: AnyUserData, rc: Arc) -> Result<()> { - lua.globals().set("userdata", userdata.clone())?; + lua.globals().set("userdata", &userdata)?; assert_eq!(Arc::strong_count(&rc), 3); - let userdata_copy = userdata.clone(); { let _value = userdata.borrow::()?; // We should not be able to take userdata if it's borrowed - match userdata_copy.take::() { + match userdata.take::() { Err(Error::UserDataBorrowMutError) => {} r => panic!("expected `UserDataBorrowMutError` error, got {:?}", r), } } - let value = userdata_copy.take::()?; + let value = userdata.take::()?; assert_eq!(*value.0, 18); drop(value); - lua.gc_collect()?; - assert_eq!(Arc::strong_count(&rc), 1); + assert_eq!(Arc::strong_count(&rc), 2); match userdata.borrow::() { Err(Error::UserDataDestructed) => {} @@ -327,11 +357,20 @@ fn test_userdata_take() -> Result<()> { } match lua.load("userdata:num()").exec() { Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::CallbackDestructed => {} - err => panic!("expected `CallbackDestructed`, got {:?}", err), + Error::UserDataDestructed => {} + err => panic!("expected `UserDataDestructed`, got {:?}", err), }, r => panic!("improper return for destructed userdata: {:?}", r), } + + assert!(!userdata.is::()); + + drop(userdata); + lua.globals().raw_remove("userdata")?; + lua.gc_collect()?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&rc), 1); + Ok(()) } @@ -343,7 +382,7 @@ fn test_userdata_take() -> Result<()> { check_userdata_take(&lua, userdata, rc)?; // Additionally check serializable userdata - #[cfg(feature = "serialize")] + #[cfg(feature = "serde")] { let rc = Arc::new(18); let userdata = lua.create_ser_userdata(MyUserdata(rc.clone()))?; @@ -356,9 +395,20 @@ fn test_userdata_take() -> Result<()> { #[test] fn test_userdata_destroy() -> Result<()> { - struct MyUserdata(Arc<()>); + struct MyUserdata(#[allow(unused)] Arc<()>); - impl UserData for MyUserdata {} + impl UserData for MyUserdata { + fn add_methods>(methods: &mut M) { + methods.add_method("try_destroy", |lua, _this, ()| { + let ud = lua.globals().get::("ud")?; + match ud.destroy() { + Err(Error::UserDataBorrowMutError) => {} + r => panic!("expected `UserDataBorrowMutError` error, got {:?}", r), + } + Ok(()) + }); + } + } let rc = Arc::new(()); @@ -376,6 +426,56 @@ fn test_userdata_destroy() -> Result<()> { assert_eq!(Arc::strong_count(&rc), 1); + let ud = lua.create_userdata(MyUserdata(rc.clone()))?; + assert_eq!(Arc::strong_count(&rc), 2); + let ud_ref = ud.borrow::()?; + // With active `UserDataRef` this methods only marks userdata as destructed + // without running destructor + ud.destroy().unwrap(); + assert_eq!(Arc::strong_count(&rc), 2); + drop(ud_ref); + assert_eq!(Arc::strong_count(&rc), 1); + + // We cannot destroy (internally) borrowed userdata + let ud = lua.create_userdata(MyUserdata(rc.clone()))?; + lua.globals().set("ud", &ud)?; + lua.load("ud:try_destroy()").exec().unwrap(); + ud.destroy().unwrap(); + assert_eq!(Arc::strong_count(&rc), 1); + + Ok(()) +} + +#[test] +fn test_userdata_method_once() -> Result<()> { + struct MyUserdata(Arc); + + impl UserData for MyUserdata { + fn add_methods>(methods: &mut M) { + methods.add_method_once("take_value", |_, this, ()| Ok(*this.0)); + } + } + + let lua = Lua::new(); + let rc = Arc::new(42); + let userdata = lua.create_userdata(MyUserdata(rc.clone()))?; + lua.globals().set("userdata", &userdata)?; + + // Control userdata + let userdata2 = lua.create_userdata(MyUserdata(rc.clone()))?; + lua.globals().set("userdata2", userdata2)?; + + assert_eq!(lua.load("userdata:take_value()").eval::()?, 42); + match lua.load("userdata2.take_value(userdata)").eval::() { + Err(Error::CallbackError { cause, .. }) => { + let err = cause.to_string(); + assert!(err.contains("bad argument `self` to `MyUserdata.take_value`")); + assert!(err.contains("userdata has been destructed")); + } + r => panic!("expected Err(CallbackError), got {r:?}"), + } + assert_eq!(Arc::strong_count(&rc), 2); + Ok(()) } @@ -391,21 +491,22 @@ fn test_user_values() -> Result<()> { ud.set_nth_user_value(1, "hello")?; ud.set_nth_user_value(2, "world")?; ud.set_nth_user_value(65535, 321)?; - assert_eq!(ud.get_nth_user_value::(1)?, "hello"); - assert_eq!(ud.get_nth_user_value::(2)?, "world"); - assert_eq!(ud.get_nth_user_value::(3)?, Value::Nil); - assert_eq!(ud.get_nth_user_value::(65535)?, 321); + assert_eq!(ud.nth_user_value::(1)?, "hello"); + assert_eq!(ud.nth_user_value::(2)?, "world"); + assert_eq!(ud.nth_user_value::(3)?, Value::Nil); + assert_eq!(ud.nth_user_value::(65535)?, 321); - assert!(ud.get_nth_user_value::(0).is_err()); - assert!(ud.get_nth_user_value::(65536).is_err()); + assert!(ud.nth_user_value::(0).is_err()); + assert!(ud.nth_user_value::(65536).is_err()); // Named user values + let ud = lua.create_userdata(MyUserData)?; ud.set_named_user_value("name", "alex")?; ud.set_named_user_value("age", 10)?; - assert_eq!(ud.get_named_user_value::<_, String>("name")?, "alex"); - assert_eq!(ud.get_named_user_value::<_, i32>("age")?, 10); - assert_eq!(ud.get_named_user_value::<_, Value>("nonexist")?, Value::Nil); + assert_eq!(ud.named_user_value::("name")?, "alex"); + assert_eq!(ud.named_user_value::("age")?, 10); + assert_eq!(ud.named_user_value::("nonexist")?, Value::Nil); Ok(()) } @@ -415,10 +516,8 @@ fn test_functions() -> Result<()> { struct MyUserData(i64); impl UserData for MyUserData { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_function("get_value", |_, ud: AnyUserData| { - Ok(ud.borrow::()?.0) - }); + fn add_methods>(methods: &mut M) { + methods.add_function("get_value", |_, ud: AnyUserData| Ok(ud.borrow::()?.0)); methods.add_function_mut("set_value", |_, (ud, value): (AnyUserData, i64)| { ud.borrow_mut::()?.0 = value; Ok(()) @@ -430,7 +529,7 @@ fn test_functions() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); let userdata = lua.create_userdata(MyUserData(42))?; - globals.set("userdata", userdata.clone())?; + globals.set("userdata", &userdata)?; lua.load( r#" function get_it() @@ -447,42 +546,46 @@ fn test_functions() -> Result<()> { "#, ) .exec()?; - let get = globals.get::<_, Function>("get_it")?; - let set = globals.get::<_, Function>("set_it")?; - let get_constant = globals.get::<_, Function>("get_constant")?; - assert_eq!(get.call::<_, i64>(())?, 42); + let get = globals.get::("get_it")?; + let set = globals.get::("set_it")?; + let get_constant = globals.get::("get_constant")?; + assert_eq!(get.call::(())?, 42); userdata.borrow_mut::()?.0 = 64; - assert_eq!(get.call::<_, i64>(())?, 64); - set.call::<_, ()>(100)?; - assert_eq!(get.call::<_, i64>(())?, 100); - assert_eq!(get_constant.call::<_, i64>(())?, 7); + assert_eq!(get.call::(())?, 64); + set.call::<()>(100)?; + assert_eq!(get.call::(())?, 100); + assert_eq!(get_constant.call::(())?, 7); Ok(()) } #[test] fn test_fields() -> Result<()> { + let lua = Lua::new(); + let globals = lua.globals(); + #[derive(Copy, Clone)] struct MyUserData(i64); impl UserData for MyUserData { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { + fields.add_field("static", "constant"); fields.add_field_method_get("val", |_, data| Ok(data.0)); fields.add_field_method_set("val", |_, data, val| { data.0 = val; Ok(()) }); - // Use userdata "uservalue" storage - fields.add_field_function_get("uval", |_, ud| ud.get_user_value::>()); - fields - .add_field_function_set("uval", |_, ud, s| ud.set_user_value::>(s)); - - fields.add_meta_field_with(MetaMethod::Index, |lua| { - let index = lua.create_table()?; - index.set("f", 321)?; - Ok(index) + // Field that emulates method + fields.add_field_function_get("val_fget", |lua, ud| { + lua.create_function(move |_, ()| Ok(ud.borrow::()?.0)) }); + + // Use userdata "uservalue" storage + fields.add_field_function_get("uval", |_, ud| ud.user_value::>()); + fields.add_field_function_set("uval", |_, ud, s: Option| ud.set_user_value(s)); + + fields.add_meta_field(MetaMethod::Index, HashMap::from([("f", 321)])); fields.add_meta_field_with(MetaMethod::NewIndex, |lua| { lua.create_function(|lua, (_, field, val): (AnyUserData, String, Value)| { lua.globals().set(field, val)?; @@ -490,16 +593,20 @@ fn test_fields() -> Result<()> { }) }) } + + fn add_methods>(methods: &mut M) { + methods.add_method("dummy", |_, _, ()| Ok(())); + } } - let lua = Lua::new(); - let globals = lua.globals(); globals.set("ud", MyUserData(7))?; lua.load( r#" + assert(ud.static == "constant") assert(ud.val == 7) ud.val = 10 assert(ud.val == 10) + assert(ud:val_fget() == 10) assert(ud.uval == nil) ud.uval = "hello" @@ -513,41 +620,67 @@ fn test_fields() -> Result<()> { ) .exec()?; + // Case: fields + __index metamethod (function) + struct MyUserData2(i64); + + impl UserData for MyUserData2 { + fn add_fields>(fields: &mut F) { + fields.add_field("z", 0); + fields.add_field_method_get("x", |_, data| Ok(data.0)); + } + + fn add_methods>(methods: &mut M) { + methods.add_meta_method(MetaMethod::Index, |_, _, name: LuaString| { + match name.to_str()?.as_ref() { + "y" => Ok(Some(-1)), + _ => Ok(None), + } + }); + } + } + + globals.set("ud", MyUserData2(1))?; + lua.load( + r#" + assert(ud.x == 1) + assert(ud.y == -1) + assert(ud.z == 0) + "#, + ) + .exec()?; + Ok(()) } #[test] fn test_metatable() -> Result<()> { #[derive(Copy, Clone)] - struct MyUserData(i64); + struct MyUserData; impl UserData for MyUserData { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - fields.add_meta_field_with("__type_name", |_| Ok("MyUserData")); - } - - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_function("my_type_name", |_, data: AnyUserData| { - let metatable = data.get_metatable()?; - metatable.get::<_, String>("__type_name") + let metatable = data.metatable()?; + metatable.get::(MetaMethod::Type) }); } } let lua = Lua::new(); let globals = lua.globals(); - globals.set("ud", MyUserData(7))?; - lua.load( - r#" - assert(ud:my_type_name() == "MyUserData") - "#, - ) - .exec()?; + globals.set("ud", MyUserData)?; + lua.load(r#"assert(ud:my_type_name() == "MyUserData")"#).exec()?; + + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53", feature = "luau"))] + lua.load(r#"assert(tostring(ud):sub(1, 11) == "MyUserData:")"#) + .exec()?; + #[cfg(feature = "luau")] + lua.load(r#"assert(typeof(ud) == "MyUserData")"#).exec()?; let ud: AnyUserData = globals.get("ud")?; - let metatable = ud.get_metatable()?; + let metatable = ud.metatable()?; - match metatable.get::<_, Value>("__gc") { + match metatable.get::("__gc") { Ok(_) => panic!("expected MetaMethodRestricted, got no error"), Err(Error::MetaMethodRestricted(_)) => {} Err(e) => panic!("expected MetaMethodRestricted, got {:?}", e), @@ -561,98 +694,814 @@ fn test_metatable() -> Result<()> { let mut methods = metatable .pairs() - .into_iter() .map(|kv: Result<(_, Value)>| Ok(kv?.0)) .collect::>>()?; - methods.sort_by_cached_key(|k| k.name().to_owned()); - assert_eq!(methods, vec![MetaMethod::Index, "__type_name".into()]); + methods.sort(); + assert_eq!(methods, vec!["__index", MetaMethod::Type.name()]); #[derive(Copy, Clone)] - struct MyUserData2(i64); + struct MyUserData2; impl UserData for MyUserData2 { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { fields.add_meta_field_with("__index", |_| Ok(1)); } } - match lua.create_userdata(MyUserData2(1)) { + match lua.create_userdata(MyUserData2) { Ok(_) => panic!("expected MetaMethodTypeError, got no error"), Err(Error::MetaMethodTypeError { .. }) => {} Err(e) => panic!("expected MetaMethodTypeError, got {:?}", e), } + #[derive(Copy, Clone)] + struct MyUserData3; + + impl UserData for MyUserData3 { + fn add_fields>(fields: &mut F) { + fields.add_meta_field_with(MetaMethod::Type, |_| Ok("CustomName")); + } + } + + let ud = lua.create_userdata(MyUserData3)?; + let metatable = ud.metatable()?; + assert_eq!( + metatable.get::(MetaMethod::Type)?.to_str()?, + "CustomName" + ); + + Ok(()) +} + +#[test] +fn test_userdata_type_name() -> Result<()> { + struct MyUserData; + impl UserData for MyUserData {} + + struct MyUserdataCustom; + impl UserData for MyUserdataCustom { + fn add_fields>(fields: &mut F) { + fields.add_meta_field_with(MetaMethod::Type, |_| Ok("MyCustomName")); + } + } + + // mlua always sets __name/__type; override with a non-string to test the "userdata" fallback + struct MyUserdataInvalid; + impl UserData for MyUserdataInvalid { + fn add_fields>(fields: &mut F) { + fields.add_meta_field_with(MetaMethod::Type, |_| Ok(42_i64)); + } + } + + let lua = Lua::new(); + + // Default is the Rust type name + let ud = lua.create_userdata(MyUserData)?; + assert_eq!(ud.type_name()?, "MyUserData"); + + // Custom name from metatable + let ud = lua.create_userdata(MyUserdataCustom)?; + assert_eq!(ud.type_name()?, "MyCustomName"); + + // Invalid type name should fallback to "userdata" + let ud = lua.create_userdata(MyUserdataInvalid)?; + assert_eq!(ud.type_name()?.to_str()?, "userdata"); + Ok(()) } #[test] -fn test_userdata_wrapped() -> Result<()> { +fn test_userdata_proxy() -> Result<()> { struct MyUserData(i64); impl UserData for MyUserData { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fn add_fields>(fields: &mut F) { + fields.add_field("static_field", 123); + fields.add_field_method_get("n", |_, this| Ok(this.0)); + } + + fn add_methods>(methods: &mut M) { + methods.add_function("new", |_, n| Ok(Self(n))); + + methods.add_method("plus", |_, this, n: i64| Ok(this.0 + n)); + } + } + + let lua = Lua::new(); + let globals = lua.globals(); + globals.set("MyUserData", lua.create_proxy::()?)?; + + assert!(!globals.get::("MyUserData")?.is_proxy::<()>()); + assert!(globals.get::("MyUserData")?.is_proxy::()); + + lua.load( + r#" + assert(MyUserData.static_field == 123) + local data = MyUserData.new(321) + assert(data.static_field == 123) + assert(data.n == 321) + assert(data:plus(1) == 322) + + -- Error when accessing the proxy object fields and methods that require instance + + local ok = pcall(function() return MyUserData.n end) + assert(not ok) + + ok = pcall(function() return MyUserData:plus(1) end) + assert(not ok) + "#, + ) + .exec() +} + +#[test] +fn test_any_userdata() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::(|reg| { + reg.add_method("get", |_, this, ()| Ok(this.clone())); + reg.add_method_mut("concat", |_, this, s: LuaString| { + this.push_str(&s.to_string_lossy()); + Ok(()) + }); + })?; + + let ud = lua.create_any_userdata("hello".to_string())?; + assert_eq!(&*ud.borrow::()?, "hello"); + + lua.globals().set("ud", ud)?; + lua.load( + r#" + assert(ud:get() == "hello") + ud:concat(", world") + assert(ud:get() == "hello, world") + "#, + ) + .exec() + .unwrap(); + + Ok(()) +} + +#[test] +fn test_any_userdata_wrap() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::(|reg| { + reg.add_method("get", |_, this, ()| Ok(this.clone())); + })?; + + lua.globals().set("s", AnyUserData::wrap("hello".to_string()))?; + lua.load( + r#" + assert(s:get() == "hello") + "#, + ) + .exec() + .unwrap(); + + Ok(()) +} + +#[test] +fn test_userdata_object_like() -> Result<()> { + let lua = Lua::new(); + + #[derive(Clone, Copy)] + struct MyUserData(u32); + + impl UserData for MyUserData { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("n", |_, this| Ok(this.0)); + fields.add_field_method_set("n", |_, this, val| { + this.0 = val; + Ok(()) + }); + } + + fn add_methods>(methods: &mut M) { + methods.add_meta_method(MetaMethod::Call, |_, _this, ()| Ok("called")); + methods.add_method_mut("add", |_, this, x: u32| { + this.0 += x; + Ok(()) + }); + } + } + + let ud = lua.create_userdata(MyUserData(123))?; + + assert_eq!(ud.get::("n")?, 123); + ud.set("n", 321)?; + assert_eq!(ud.get::("n")?, 321); + assert_eq!(ud.get::>("non-existent")?, None); + match ud.set("non-existent", 123) { + Err(Error::RuntimeError(_)) => {} + r => panic!("expected RuntimeError, got {r:?}"), + } + + assert_eq!(ud.call::(())?, "called"); + + ud.call_method::<()>("add", 2)?; + assert_eq!(ud.get::("n")?, 323); + + match ud.call_method::<()>("non_existent", ()) { + Err(Error::RuntimeError(err)) => { + assert!(err.contains("attempt to call a nil value (function 'non_existent')")) + } + r => panic!("expected RuntimeError, got {r:?}"), + } + + assert!(ud.to_string()?.starts_with("MyUserData")); + + Ok(()) +} + +#[test] +fn test_userdata_method_errors() -> Result<()> { + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_methods>(methods: &mut M) { + methods.add_method("get_value", |_, data, ()| Ok(data.0)); + } + } + + let lua = Lua::new(); + + let ud = lua.create_userdata(MyUserData(123))?; + let res = ud.call_function::<()>("get_value", "not a userdata"); + match res { + Err(Error::CallbackError { cause, .. }) => match cause.as_ref() { + Error::BadArgument { + to, + name, + cause: cause2, + .. + } => { + assert_eq!(to.as_deref(), Some("MyUserData.get_value")); + assert_eq!(name.as_deref(), Some("self")); + assert_eq!( + cause2.to_string(), + "error converting Lua string to userdata (expected userdata of type 'MyUserData')" + ); + } + err => panic!("expected BadArgument, got {err:?}"), + }, + r => panic!("expected CallbackError, got {r:?}"), + } + + Ok(()) +} + +#[test] +fn test_userdata_pointer() -> Result<()> { + let lua = Lua::new(); + + let ud1 = lua.create_any_userdata("hello")?; + let ud2 = lua.create_any_userdata("hello")?; + + assert_eq!(ud1.to_pointer(), ud1.clone().to_pointer()); + // Different userdata objects with the same value should have different pointers + assert_ne!(ud1.to_pointer(), ud2.to_pointer()); + + Ok(()) +} + +#[cfg(feature = "macros")] +#[test] +fn test_userdata_derive() -> Result<()> { + let lua = Lua::new(); + + // Simple struct + + #[derive(Clone, Copy, mlua::FromLua)] + struct MyUserData(i32); + + lua.register_userdata_type::(|reg| { + reg.add_function("val", |_, this: MyUserData| Ok(this.0)); + })?; + + lua.globals().set("ud", AnyUserData::wrap(MyUserData(123)))?; + lua.load("assert(ud:val() == 123)").exec()?; + + // More complex struct where generics and where clause + + #[derive(Clone, Copy, mlua::FromLua)] + struct MyUserData2<'a, T: ?Sized>(&'a T) + where + T: Copy; + + lua.register_userdata_type::>(|reg| { + reg.add_function("val", |_, this: MyUserData2<'static, i32>| Ok(*this.0)); + })?; + + lua.globals().set("ud", AnyUserData::wrap(MyUserData2(&321)))?; + lua.load("assert(ud:val() == 321)").exec()?; + + Ok(()) +} + +#[test] +fn test_nested_userdata_gc() -> Result<()> { + let lua = Lua::new(); + + let counter = Arc::new(()); + let arr = vec![lua.create_any_userdata(counter.clone())?]; + let arr_ud = lua.create_any_userdata(arr)?; + + assert_eq!(Arc::strong_count(&counter), 2); + drop(arr_ud); + // On first iteration Lua will destroy the array, on second - userdata + lua.gc_collect()?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&counter), 1); + + Ok(()) +} + +#[cfg(feature = "userdata-wrappers")] +#[test] +fn test_userdata_wrappers() -> Result<()> { + #[derive(Debug)] + struct MyUserData(i64); + + impl UserData for MyUserData { + fn add_fields>(fields: &mut F) { + fields.add_field("static", "constant"); fields.add_field_method_get("data", |_, this| Ok(this.0)); fields.add_field_method_set("data", |_, this, val| { this.0 = val; Ok(()) }) } + + fn add_methods>(methods: &mut M) { + methods.add_method("dbg", |_, this, ()| Ok(format!("{this:?}"))); + } } let lua = Lua::new(); let globals = lua.globals(); + // Rc #[cfg(not(feature = "send"))] { - let ud1 = Rc::new(RefCell::new(MyUserData(1))); - globals.set("rc_refcell_ud", ud1.clone())?; + use std::rc::Rc; + + let ud = Rc::new(MyUserData(1)); + globals.set("ud", ud.clone())?; lua.load( r#" - rc_refcell_ud.data = rc_refcell_ud.data + 1 - assert(rc_refcell_ud.data == 2) + assert(ud.static == "constant") + local ok, err = pcall(function() ud.data = 2 end) + assert( + tostring(err):find("error mutably borrowing userdata") ~= nil, + "expected 'error mutably borrowing userdata', got '" .. tostring(err) .. "'" + ) + assert(ud.data == 1) + assert(ud:dbg(), "MyUserData(1)") "#, ) - .exec()?; - assert_eq!(ud1.borrow().0, 2); - globals.set("rc_refcell_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 1); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + assert!(ud.borrow_mut::>().is_ok()); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 1); + assert!(matches!( + ud.borrow_mut_scoped::(|_| ()), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; - assert_eq!(Rc::strong_count(&ud1), 1); + assert_eq!(Rc::strong_count(&ud), 1); + + // We must be able to mutate userdata when having one reference only + globals.set("ud", ud)?; + lua.load( + r#" + ud.data = 2 + assert(ud.data == 2) + "#, + ) + .exec() + .unwrap(); } - let ud2 = Arc::new(Mutex::new(MyUserData(2))); - globals.set("arc_mutex_ud", ud2.clone())?; - lua.load( - r#" - arc_mutex_ud.data = arc_mutex_ud.data + 1 - assert(arc_mutex_ud.data == 3) - "#, - ) - .exec()?; - #[cfg(not(feature = "parking_lot"))] - assert_eq!(ud2.lock().unwrap().0, 3); - #[cfg(feature = "parking_lot")] - assert_eq!(ud2.lock().0, 3); + // Rc> + #[cfg(not(feature = "send"))] + { + use std::cell::RefCell; + use std::rc::Rc; + + let ud = Rc::new(RefCell::new(MyUserData(2))); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + assert(ud.data == 2) + ud.data = 10 + assert(ud.data == 10) + assert(ud:dbg() == "MyUserData(10)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 10); + assert_eq!(ud.borrow_mut::()?.0, 10); + ud.borrow_mut::()?.0 = 20; + assert_eq!(ud.borrow::()?.0, 20); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 20); + ud.borrow_mut_scoped::(|x| x.0 = 30)?; + assert_eq!(ud.borrow::()?.0, 30); + + // Double (read) borrow is okay + let _borrow = ud.borrow::()?; + assert_eq!(ud.borrow::()?.0, 30); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Rc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; + lua.gc_collect()?; + assert_eq!(ud.0, 30); + drop(ud); + } + + // Arc + { + let ud = Arc::new(MyUserData(3)); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + local ok, err = pcall(function() ud.data = 4 end) + assert( + tostring(err):find("error mutably borrowing userdata") ~= nil, + "expected 'error mutably borrowing userdata', got '" .. tostring(err) .. "'" + ) + assert(ud.data == 3) + assert(ud:dbg() == "MyUserData(3)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 3); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + assert!(ud.borrow_mut::>().is_ok()); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 3); + assert!(matches!( + ud.borrow_mut_scoped::(|_| ()), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&ud), 1); + + // We must be able to mutate userdata when having one reference only + globals.set("ud", ud)?; + lua.load( + r#" + ud.data = 4 + assert(ud.data == 4) + "#, + ) + .exec() + .unwrap(); + } + + // Arc> + { + use std::sync::Mutex; + + let ud = Arc::new(Mutex::new(MyUserData(5))); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + assert(ud.data == 5) + ud.data = 6 + assert(ud.data == 6) + assert(ud:dbg() == "MyUserData(6)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + #[rustfmt::skip] + assert!(matches!(ud.borrow::(), Err(Error::UserDataTypeMismatch))); + #[rustfmt::skip] + assert!(matches!(ud.borrow_mut::(), Err(Error::UserDataTypeMismatch))); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 6); + ud.borrow_mut_scoped::(|x| x.0 = 8)?; + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 8); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&ud), 1); + } + + // Arc> + { + use std::sync::RwLock; + + let ud = Arc::new(RwLock::new(MyUserData(9))); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + assert(ud.data == 9) + ud.data = 10 + assert(ud.data == 10) + assert(ud:dbg() == "MyUserData(10)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + #[rustfmt::skip] + assert!(matches!(ud.borrow::(), Err(Error::UserDataTypeMismatch))); + #[rustfmt::skip] + assert!(matches!(ud.borrow_mut::(), Err(Error::UserDataTypeMismatch))); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 10); + ud.borrow_mut_scoped::(|x| x.0 = 12)?; + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 12); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&ud), 1); + } + + // Arc> + { + use parking_lot::Mutex; + + let ud = Arc::new(Mutex::new(MyUserData(13))); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + assert(ud.data == 13) + ud.data = 14 + assert(ud.data == 14) + assert(ud:dbg() == "MyUserData(14)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 14); + assert_eq!(ud.borrow_mut::()?.0, 14); + ud.borrow_mut::()?.0 = 15; + assert_eq!(ud.borrow::()?.0, 15); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 15); + ud.borrow_mut_scoped::(|x| x.0 = 16)?; + assert_eq!(ud.borrow::()?.0, 16); + + // Double borrow is not allowed + let _borrow = ud.borrow::()?; + assert!(matches!( + ud.borrow::(), + Err(Error::UserDataBorrowError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; + lua.gc_collect()?; + assert_eq!(ud.0, 16); + drop(ud); + } + + // Arc> + { + use parking_lot::RwLock; + + let ud = Arc::new(RwLock::new(MyUserData(17))); + globals.set("ud", ud.clone())?; + lua.load( + r#" + assert(ud.static == "constant") + assert(ud.data == 17) + ud.data = 18 + assert(ud.data == 18) + assert(ud:dbg() == "MyUserData(18)") + "#, + ) + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 18); + assert_eq!(ud.borrow_mut::()?.0, 18); + ud.borrow_mut::()?.0 = 19; + assert_eq!(ud.borrow::()?.0, 19); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 19); + ud.borrow_mut_scoped::(|x| x.0 = 20)?; + assert_eq!(ud.borrow::()?.0, 20); + + // Multiple read borrows are allowed with parking_lot::RwLock + let _borrow1 = ud.borrow::().unwrap(); + // FIXME: does not work due to https://github.com/rust-lang/rust/pull/135634 + // let _borrow2 = ud.borrow::().unwrap(); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Arc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; + lua.gc_collect()?; + assert_eq!(ud.0, 20); + drop(ud); + } - let ud3 = Arc::new(RwLock::new(MyUserData(3))); - globals.set("arc_rwlock_ud", ud3.clone())?; + Ok(()) +} + +#[cfg(feature = "luau")] +#[test] +fn test_userdata_namecall() -> Result<()> { + let lua = Lua::new(); + + struct MyUserData; + + impl UserData for MyUserData { + fn register(registry: &mut UserDataRegistry) { + registry.add_method("method", |_, _, ()| Ok("method called")); + registry.add_field_method_get("field", |_, _| Ok("field value")); + + registry.add_meta_method(MetaMethod::Index, |_, _, key: LuaString| Ok(key)); + + registry.enable_namecall(); + } + } + + let ud = lua.create_userdata(MyUserData)?; + lua.globals().set("ud", &ud)?; lua.load( r#" - arc_rwlock_ud.data = arc_rwlock_ud.data + 1 - assert(arc_rwlock_ud.data == 4) - "#, + assert(ud:method() == "method called") + assert(ud.field == "field value") + assert(ud.dynamic_field == "dynamic_field") + local ok, err = pcall(function() return ud:dynamic_field() end) + assert(tostring(err):find("attempt to call an unknown method 'dynamic_field'") ~= nil) + "#, ) .exec()?; - #[cfg(not(feature = "parking_lot"))] - assert_eq!(ud3.read().unwrap().0, 4); - #[cfg(feature = "parking_lot")] - assert_eq!(ud3.read().0, 4); - - // Test drop - globals.set("arc_mutex_ud", Nil)?; - globals.set("arc_rwlock_ud", Nil)?; - lua.gc_collect()?; - assert_eq!(Arc::strong_count(&ud2), 1); - assert_eq!(Arc::strong_count(&ud3), 1); + + ud.destroy()?; + let err = lua.load("ud:method()").exec().unwrap_err(); + assert!(err.to_string().contains("userdata has been destructed")); + + Ok(()) +} + +#[test] +fn test_userdata_get_path() -> Result<()> { + let lua = Lua::new(); + + struct MyUd; + impl UserData for MyUd { + fn register(registry: &mut UserDataRegistry) { + registry.add_field("value", "userdata_value"); + } + } + + let ud = lua.create_userdata(MyUd)?; + assert_eq!(ud.get_path::(".value")?, "userdata_value"); + + Ok(()) +} + +#[test] +fn test_userdata_owned() -> Result<()> { + #[derive(Debug)] + struct MyUserdata(Arc); + + impl UserData for MyUserdata { + fn register(registry: &mut UserDataRegistry) { + registry.add_method("num", |_, this, ()| Ok(*this.0)); + } + } + + let lua = Lua::new(); + let rc = Arc::new(42); + + // It takes ownership and destructs the Lua userdata + let ud = lua.create_userdata(MyUserdata(rc.clone()))?; + assert_eq!(Arc::strong_count(&rc), 2); + let owned: UserDataOwned = lua.convert(&ud)?; + assert_eq!(*owned.0.0, 42); + drop(owned); + assert_eq!(Arc::strong_count(&rc), 1); + match ud.borrow::() { + Err(Error::UserDataDestructed) => {} + r => panic!("expected UserDataDestructed, got {:?}", r), + } + + // Cannot take while borrowed + let rc = Arc::new(7); + let ud = lua.create_userdata(MyUserdata(rc.clone()))?; + let borrowed = ud.borrow::()?; + match lua.convert::>(&ud) { + Err(Error::UserDataBorrowMutError) => {} + r => panic!("expected UserDataBorrowMutError, got {:?}", r), + } + drop(borrowed); + + // Works as a function parameter + let f = lua.create_function(|_, owned: UserDataOwned| Ok(*owned.0.0))?; + let rc = Arc::new(55); + let ud = lua.create_userdata(MyUserdata(rc.clone()))?; + assert_eq!(f.call::(ud)?, 55); + assert_eq!(Arc::strong_count(&rc), 1); // dropped after call Ok(()) } diff --git a/tests/value.rs b/tests/value.rs index 2f9c1e7b..9ed3b2bf 100644 --- a/tests/value.rs +++ b/tests/value.rs @@ -1,6 +1,11 @@ +use std::collections::HashMap; +use std::os::raw::c_void; use std::ptr; -use mlua::{Lua, Result, Value}; +use mlua::{ + AnyUserData, Error, LightUserData, Lua, MultiValue, Result, UserData, UserDataMethods, UserDataRegistry, + Value, +}; #[test] fn test_value_eq() -> Result<()> { @@ -28,6 +33,7 @@ fn test_value_eq() -> Result<()> { "#, ) .exec()?; + globals.set("null", Value::NULL)?; let table1: Value = globals.get("table1")?; let table2: Value = globals.get("table2")?; @@ -41,25 +47,294 @@ fn test_value_eq() -> Result<()> { let func3: Value = globals.get("func3")?; let thread1: Value = globals.get("thread1")?; let thread2: Value = globals.get("thread2")?; + let null: Value = globals.get("null")?; assert!(table1 != table2); assert!(table1.equals(&table2)?); assert!(string1 == string2); assert!(string1.equals(&string2)?); assert!(num1 == num2); - assert!(num1.equals(num2)?); + assert!(num1.equals(&num2)?); assert!(num1 != num3); assert!(func1 == func2); assert!(func1 != func3); assert!(!func1.equals(&func3)?); assert!(thread1 == thread2); assert!(thread1.equals(&thread2)?); + assert!(null == Value::NULL); assert!(!table1.to_pointer().is_null()); assert!(!ptr::eq(table1.to_pointer(), table2.to_pointer())); - assert!(ptr::eq(string1.to_pointer(), string2.to_pointer())); + assert!(ptr::eq(string1.to_pointer(), string2.to_pointer()) && !string1.to_pointer().is_null()); assert!(ptr::eq(func1.to_pointer(), func2.to_pointer())); assert!(num1.to_pointer().is_null()); Ok(()) } + +#[test] +fn test_multi_value() { + let mut multi_value = MultiValue::new(); + assert_eq!(multi_value.len(), 0); + assert_eq!(multi_value.get(0), None); + + multi_value.push_front(Value::Number(2.)); + multi_value.push_front(Value::Number(1.)); + assert_eq!(multi_value.get(0), Some(&Value::Number(1.))); + assert_eq!(multi_value.get(1), Some(&Value::Number(2.))); + + assert_eq!(multi_value.pop_front(), Some(Value::Number(1.))); + assert_eq!(multi_value[0], Value::Number(2.)); + + multi_value.clear(); + assert!(multi_value.is_empty()); +} + +#[test] +fn test_value_to_pointer() -> Result<()> { + let lua = Lua::new(); + + let globals = lua.globals(); + lua.load( + r#" + table = {} + string = "hello" + num = 1 + func = function() end + thread = coroutine.create(function() end) + "#, + ) + .exec()?; + globals.set("null", Value::NULL)?; + + let table: Value = globals.get("table")?; + let string: Value = globals.get("string")?; + let num: Value = globals.get("num")?; + let func: Value = globals.get("func")?; + let thread: Value = globals.get("thread")?; + let null: Value = globals.get("null")?; + let ud: Value = Value::UserData(lua.create_any_userdata(())?); + + assert!(!table.to_pointer().is_null()); + assert!(!string.to_pointer().is_null()); + assert!(num.to_pointer().is_null()); + assert!(!func.to_pointer().is_null()); + assert!(!thread.to_pointer().is_null()); + assert!(null.to_pointer().is_null()); + assert!(!ud.to_pointer().is_null()); + + Ok(()) +} + +#[test] +fn test_value_to_string() -> Result<()> { + let lua = Lua::new(); + + assert_eq!(Value::Nil.to_string()?, "nil"); + assert_eq!(Value::Nil.type_name(), "nil"); + assert_eq!(Value::Boolean(true).to_string()?, "true"); + assert_eq!(Value::Boolean(true).type_name(), "boolean"); + assert_eq!(Value::NULL.to_string()?, "null"); + assert_eq!(Value::NULL.type_name(), "lightuserdata"); + assert_eq!( + Value::LightUserData(LightUserData(0x1 as *const c_void as *mut _)).to_string()?, + "lightuserdata: 0x1" + ); + assert_eq!(Value::Integer(1).to_string()?, "1"); + assert_eq!(Value::Integer(1).type_name(), "integer"); + assert_eq!(Value::Number(34.59).to_string()?, "34.59"); + assert_eq!(Value::Number(34.59).type_name(), "number"); + #[cfg(all(feature = "luau", not(feature = "luau-vector4")))] + assert_eq!( + Value::Vector(mlua::Vector::new(10.0, 11.1, 12.2)).to_string()?, + "vector(10, 11.1, 12.2)" + ); + #[cfg(all(feature = "luau", not(feature = "luau-vector4")))] + assert_eq!( + Value::Vector(mlua::Vector::new(10.0, 11.1, 12.2)).type_name(), + "vector" + ); + #[cfg(feature = "luau-vector4")] + assert_eq!( + Value::Vector(mlua::Vector::new(10.0, 11.1, 12.2, 13.3)).to_string()?, + "vector(10, 11.1, 12.2, 13.3)" + ); + + let s = Value::String(lua.create_string("hello")?); + assert_eq!(s.to_string()?, "hello"); + assert_eq!(s.type_name(), "string"); + + let table: Value = lua.load("{}").eval()?; + assert!(table.to_string()?.starts_with("table:")); + let table: Value = lua + .load("setmetatable({}, {__tostring = function() return 'test table' end})") + .eval()?; + assert_eq!(table.to_string()?, "test table"); + assert_eq!(table.type_name(), "table"); + + let func: Value = lua.load("function() end").eval()?; + assert!(func.to_string()?.starts_with("function:")); + assert_eq!(func.type_name(), "function"); + + let thread: Value = lua.load("coroutine.create(function() end)").eval()?; + assert!(thread.to_string()?.starts_with("thread:")); + assert_eq!(thread.type_name(), "thread"); + + lua.register_userdata_type::(|reg| { + reg.add_meta_method("__tostring", |_, this, ()| Ok(this.clone())); + })?; + let ud: Value = Value::UserData(lua.create_any_userdata(String::from("string userdata"))?); + assert_eq!(ud.to_string()?, "string userdata"); + assert_eq!(ud.type_name(), "userdata"); + + struct MyUserData; + impl UserData for MyUserData {} + let ud: Value = Value::UserData(lua.create_userdata(MyUserData)?); + assert!(ud.to_string()?.starts_with("MyUserData:")); + + let err = Value::Error(Box::new(Error::runtime("test error"))); + assert_eq!(err.to_string()?, "runtime error: test error"); + assert_eq!(err.type_name(), "error"); + + #[cfg(feature = "luau")] + { + let buf = Value::Buffer(lua.create_buffer(b"hello")?); + assert!(buf.to_string()?.starts_with("buffer:")); + assert_eq!(buf.type_name(), "buffer"); + + // Set `__tostring` metamethod for buffer + let mt = lua.load("{__tostring = buffer.tostring}").eval()?; + lua.set_type_metatable::(mt); + assert_eq!(buf.to_string()?, "hello"); + } + + Ok(()) +} + +#[test] +fn test_debug_format() -> Result<()> { + let lua = Lua::new(); + + lua.register_userdata_type::>(|_| {})?; + let ud = lua + .create_any_userdata::>(HashMap::new()) + .map(Value::UserData)?; + assert!(format!("{ud:#?}").starts_with("HashMap:")); + + struct ToDebugUserData; + impl UserData for ToDebugUserData { + fn register(registry: &mut UserDataRegistry) { + registry.add_meta_method("__tostring", |_, _, ()| Ok("regular-string")); + registry.add_meta_method("__todebugstring", |_, _, ()| Ok("debug-string")); + } + } + let debug_ud = Value::UserData(lua.create_userdata(ToDebugUserData)?); + assert_eq!(debug_ud.to_string()?, "regular-string"); + assert_eq!(format!("{debug_ud:#?}"), "debug-string"); + + struct ToStringUserData; + impl UserData for ToStringUserData { + fn register(registry: &mut UserDataRegistry) { + registry.add_meta_method("__tostring", |_, _, ()| Ok("regular-string")); + } + } + let tostring_only_ud = Value::UserData(lua.create_userdata(ToStringUserData)?); + assert_eq!(format!("{tostring_only_ud:#?}"), "regular-string"); + + // Check that `AnyUsedata` pretty debug format is same as for `Value::UserData` + let any_ud: AnyUserData = lua.create_userdata(ToDebugUserData)?; + let value_ud = Value::UserData(any_ud.clone()); + assert_eq!(format!("{any_ud:#?}"), format!("{value_ud:#?}")); + + Ok(()) +} + +#[test] +fn test_value_conversions() -> Result<()> { + let lua = Lua::new(); + + assert!(Value::Nil.is_nil()); + assert!(!Value::NULL.is_nil()); + assert!(Value::NULL.is_null()); + assert!(Value::NULL.is_light_userdata()); + assert!(Value::NULL.as_light_userdata() == Some(LightUserData(ptr::null_mut()))); + assert!(Value::Boolean(true).is_boolean()); + assert_eq!(Value::Boolean(false).as_boolean(), Some(false)); + assert!(Value::Integer(1).is_integer()); + assert_eq!(Value::Integer(1).as_integer(), Some(1)); + assert_eq!(Value::Integer(1).as_i32(), Some(1i32)); + assert_eq!(Value::Integer(1).as_u32(), Some(1u32)); + assert_eq!(Value::Integer(1).as_i64(), Some(1i64)); + assert_eq!(Value::Integer(1).as_u64(), Some(1u64)); + #[cfg(any(feature = "lua55", feature = "lua54", feature = "lua53"))] + { + assert_eq!(Value::Integer(mlua::Integer::MAX).as_i32(), None); + assert_eq!(Value::Integer(mlua::Integer::MAX).as_u32(), None); + } + assert_eq!(Value::Integer(1).as_isize(), Some(1isize)); + assert_eq!(Value::Integer(1).as_usize(), Some(1usize)); + assert!(Value::Number(1.23).is_number()); + assert_eq!(Value::Number(1.23).as_number(), Some(1.23)); + assert_eq!(Value::Number(1.23).as_f32(), Some(1.23f32)); + assert_eq!(Value::Number(1.23).as_f64(), Some(1.23f64)); + assert!(Value::String(lua.create_string("hello")?).is_string()); + assert_eq!( + Value::String(lua.create_string("hello")?).as_string().unwrap(), + "hello" + ); + assert_eq!(Value::String(lua.create_string("hello")?).to_string()?, "hello"); + assert!(Value::Table(lua.create_table()?).is_table()); + assert!(Value::Table(lua.create_table()?).as_table().is_some()); + assert!(Value::Function(lua.create_function(|_, ()| Ok(())).unwrap()).is_function()); + assert!( + Value::Function(lua.create_function(|_, ()| Ok(())).unwrap()) + .as_function() + .is_some() + ); + assert!(Value::Thread(lua.create_thread(lua.load("function() end").eval()?)?).is_thread()); + assert!( + Value::Thread(lua.create_thread(lua.load("function() end").eval()?)?) + .as_thread() + .is_some() + ); + assert!(Value::UserData(lua.create_any_userdata("hello")?).is_userdata()); + assert_eq!( + Value::UserData(lua.create_any_userdata("hello")?) + .as_userdata() + .and_then(|ud| ud.borrow::<&str>().ok()) + .as_deref(), + Some(&"hello") + ); + + assert!(Value::Error(Box::new(Error::runtime("some error"))).is_error()); + assert_eq!( + (Value::Error(Box::new(Error::runtime("some error"))).as_error()) + .unwrap() + .to_string(), + "runtime error: some error" + ); + + Ok(()) +} + +#[test] +fn test_value_exhaustive_match() { + match Value::Nil { + Value::Nil => {} + Value::Boolean(_) => {} + Value::LightUserData(_) => {} + Value::Integer(_) => {} + Value::Number(_) => {} + #[cfg(feature = "luau")] + Value::Vector(_) => {} + Value::String(_) => {} + Value::Table(_) => {} + Value::Function(_) => {} + Value::Thread(_) => {} + Value::UserData(_) => {} + #[cfg(feature = "luau")] + Value::Buffer(_) => {} + Value::Error(_) => {} + Value::Other(_) => {} + } +} diff --git a/typos.toml b/typos.toml new file mode 100644 index 00000000..7a8fc528 --- /dev/null +++ b/typos.toml @@ -0,0 +1,14 @@ +[default] +extend-ignore-identifiers-re = ["2nd", "ser"] +extend-ignore-re = [ + # Custom ignore regex patterns: https://github.com/crate-ci/typos/blob/master/docs/reference.md#defaultextend-ignore-re + ".*(?:#|--|//|/*).*(?:spellchecker|typos):\\s?ignore[^\\n]*\\n", + ".*(?:spellchecker|typos):\\s?ignore-next-line[^\\n]*\\n[^\\n]*", +] + +[files] +extend-exclude = ["tests/compile/*.stderr"] + +[default.extend-words] +thr = "thr" +aas = "aas"